refactor: improve chat completion call handling

This commit is contained in:
Carl-Robert Linnupuu 2024-10-17 02:24:57 +03:00
parent 1b3b5687bc
commit 5ad9bcfaff
27 changed files with 568 additions and 610 deletions

View file

@ -28,7 +28,7 @@ import com.intellij.openapi.vfs.VirtualFile;
import ee.carlrobert.codegpt.CodeGPTBundle;
import ee.carlrobert.codegpt.EncodingManager;
import ee.carlrobert.codegpt.Icons;
import ee.carlrobert.codegpt.completions.CommitMessageRequestParameters;
import ee.carlrobert.codegpt.completions.CommitMessageCompletionParameters;
import ee.carlrobert.codegpt.completions.CompletionRequestService;
import ee.carlrobert.codegpt.settings.configuration.CommitMessageTemplate;
import ee.carlrobert.codegpt.ui.OverlayUtil;
@ -96,7 +96,7 @@ public class GenerateGitCommitMessageAction extends AnAction {
if (editor != null) {
((EditorEx) editor).setCaretVisible(false);
CompletionRequestService.getInstance().getCommitMessageAsync(
new CommitMessageRequestParameters(
new CommitMessageCompletionParameters(
gitDiff,
project.getService(CommitMessageTemplate.class).getSystemPrompt()),
getEventListener(project, editor.getDocument()));

View file

@ -1,93 +0,0 @@
package ee.carlrobert.codegpt.completions;
import ee.carlrobert.codegpt.ReferencedFile;
import ee.carlrobert.codegpt.conversations.Conversation;
import ee.carlrobert.codegpt.conversations.message.Message;
import java.util.List;
import java.util.UUID;
import org.jetbrains.annotations.Nullable;
public class CallParameters {
private final UUID sessionId;
private final Conversation conversation;
private final ConversationType conversationType;
private final Message message;
private final boolean retry;
private final String highlightedText;
private String imageMediaType;
private byte[] imageData;
private List<ReferencedFile> referencedFiles;
public CallParameters(Conversation conversation, Message message) {
this(null, conversation, message);
}
public CallParameters(UUID sessionId, Conversation conversation, Message message) {
this(sessionId, conversation, ConversationType.DEFAULT, message, null, false);
}
// TODO: Builder
public CallParameters(
UUID sessionId,
Conversation conversation,
ConversationType conversationType,
Message message,
@Nullable String highlightedText,
boolean retry) {
this.sessionId = sessionId;
this.conversation = conversation;
this.conversationType = conversationType;
this.message = message;
this.highlightedText = highlightedText;
this.retry = retry;
}
public UUID getSessionId() {
return sessionId;
}
public Conversation getConversation() {
return conversation;
}
public ConversationType getConversationType() {
return conversationType;
}
public Message getMessage() {
return message;
}
public boolean isRetry() {
return retry;
}
public @Nullable String getImageMediaType() {
return imageMediaType;
}
public void setImageMediaType(@Nullable String imageMediaType) {
this.imageMediaType = imageMediaType;
}
public byte[] getImageData() {
return imageData;
}
public void setImageData(byte[] imageData) {
this.imageData = imageData;
}
public @Nullable String getHighlightedText() {
return highlightedText;
}
public @Nullable List<ReferencedFile> getReferencedFiles() {
return referencedFiles;
}
public void setReferencedFiles(List<ReferencedFile> referencedFiles) {
this.referencedFiles = referencedFiles;
}
}

View file

@ -10,12 +10,12 @@ import okhttp3.sse.EventSource;
public class ChatCompletionEventListener implements CompletionEventListener<String> {
private final CallParameters callParameters;
private final ChatCompletionParameters callParameters;
private final CompletionResponseEventListener eventListener;
private final StringBuilder messageBuilder = new StringBuilder();
public ChatCompletionEventListener(
CallParameters callParameters,
ChatCompletionParameters callParameters,
CompletionResponseEventListener eventListener) {
this.callParameters = callParameters;
this.eventListener = eventListener;

View file

@ -69,7 +69,7 @@ public final class CompletionRequestService {
new OpenAIChatCompletionEventSourceListener(eventListener));
}
public String getLookupCompletion(LookupRequestCallParameters params) {
public String getLookupCompletion(LookupCompletionParameters params) {
var request = CompletionRequestFactory
.getFactory(GeneralSettings.getSelectedService())
.createLookupRequest(params);
@ -77,7 +77,7 @@ public final class CompletionRequestService {
}
public EventSource getCommitMessageAsync(
CommitMessageRequestParameters params,
CommitMessageCompletionParameters params,
CompletionEventListener<String> eventListener) {
var request = CompletionRequestFactory
.getFactory(GeneralSettings.getSelectedService())
@ -86,7 +86,7 @@ public final class CompletionRequestService {
}
public EventSource getEditCodeCompletionAsync(
EditCodeRequestParameters params,
EditCodeCompletionParameters params,
CompletionEventListener<String> eventListener) {
var request = CompletionRequestFactory
.getFactory(GeneralSettings.getSelectedService())

View file

@ -19,7 +19,7 @@ public interface CompletionResponseEventListener {
default void handleCompleted(String fullMessage) {
}
default void handleCompleted(String fullMessage, CallParameters callParameters) {
default void handleCompleted(String fullMessage, ChatCompletionParameters callParameters) {
}
default void handleCodeGPTEvent(CodeGPTEvent event) {

View file

@ -57,7 +57,7 @@ public class MethodNameLookupListener implements LookupManagerListener {
String prompt) {
try {
var response = CompletionRequestService.getInstance()
.getLookupCompletion(new LookupRequestCallParameters(prompt));
.getLookupCompletion(new LookupCompletionParameters(prompt));
if (!response.isEmpty()) {
for (var value : response.split(",")) {
application.invokeLater(() -> application.runReadAction(() -> {

View file

@ -15,7 +15,7 @@ public class ToolwindowChatCompletionRequestHandler {
this.completionResponseEventListener = completionResponseEventListener;
}
public void call(CallParameters callParameters) {
public void call(ChatCompletionParameters callParameters) {
try {
eventSource = startCall(callParameters);
} catch (TotalUsageExceededException e) {
@ -33,11 +33,11 @@ public class ToolwindowChatCompletionRequestHandler {
}
}
private EventSource startCall(CallParameters callParameters) {
private EventSource startCall(ChatCompletionParameters callParameters) {
try {
var request = CompletionRequestFactory
.getFactory(GeneralSettings.getSelectedService())
.createChatRequest(new ChatCompletionRequestParameters(callParameters));
.createChatRequest(callParameters);
return CompletionRequestService.getInstance().getChatCompletionAsync(
request,
new ChatCompletionEventListener(callParameters, completionResponseEventListener));
@ -57,7 +57,7 @@ public class ToolwindowChatCompletionRequestHandler {
completionResponseEventListener.handleError(new ErrorDetails(errorMessage), ex);
}
private void sendInfo(CallParameters callParameters) {
private void sendInfo(ChatCompletionParameters callParameters) {
TelemetryAction.COMPLETION.createActionMessage()
.property("conversationId", callParameters.getConversation().getId().toString())
.property("model", callParameters.getConversation().getModel())

View file

@ -2,7 +2,7 @@ package ee.carlrobert.codegpt.conversations;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.components.Service;
import ee.carlrobert.codegpt.completions.CallParameters;
import ee.carlrobert.codegpt.completions.ChatCompletionParameters;
import ee.carlrobert.codegpt.conversations.message.Message;
import ee.carlrobert.codegpt.settings.GeneralSettings;
import ee.carlrobert.codegpt.settings.service.ServiceType;
@ -63,11 +63,11 @@ public final class ConversationService {
conversationsMapping.put(conversation.getClientCode(), conversations);
}
public void saveMessage(String response, CallParameters callParameters) {
public void saveMessage(String response, ChatCompletionParameters callParameters) {
var conversation = callParameters.getConversation();
var message = callParameters.getMessage();
var conversationMessages = conversation.getMessages();
if (callParameters.isRetry() && !conversationMessages.isEmpty()) {
if (callParameters.getRetry() && !conversationMessages.isEmpty()) {
var messageToBeSaved = conversationMessages.stream()
.filter(item -> item.getId().equals(message.getId()))
.findFirst().orElseThrow();

View file

@ -2,12 +2,10 @@ package ee.carlrobert.codegpt.toolwindow.chat;
import static ee.carlrobert.codegpt.ui.UIUtil.createScrollPaneWithSmartScroller;
import static java.lang.String.format;
import static java.util.Collections.emptyList;
import com.intellij.openapi.Disposable;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.diagnostic.Logger;
import com.intellij.openapi.editor.Editor;
import com.intellij.openapi.editor.SelectionModel;
import com.intellij.openapi.editor.ex.EditorEx;
import com.intellij.openapi.editor.impl.EditorImpl;
@ -17,7 +15,7 @@ import com.intellij.util.ui.JBUI;
import ee.carlrobert.codegpt.CodeGPTKeys;
import ee.carlrobert.codegpt.ReferencedFile;
import ee.carlrobert.codegpt.actions.ActionType;
import ee.carlrobert.codegpt.completions.CallParameters;
import ee.carlrobert.codegpt.completions.ChatCompletionParameters;
import ee.carlrobert.codegpt.completions.CompletionRequestService;
import ee.carlrobert.codegpt.completions.ConversationType;
import ee.carlrobert.codegpt.completions.ToolwindowChatCompletionRequestHandler;
@ -43,7 +41,6 @@ import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.UUID;
@ -121,8 +118,28 @@ public class ChatToolWindowTabPanel implements Disposable {
totalTokensPanel.updateConversationTokens(conversation);
}
public void sendMessage(Message message) {
sendMessage(message, ConversationType.DEFAULT);
public List<ReferencedFile> getReferencedFiles() {
List<ReferencedFile> referencedFiles = project.getUserData(CodeGPTKeys.SELECTED_FILES);
if (referencedFiles == null) {
return conversation.getMessages().stream()
.flatMap(prevMessage -> {
if (prevMessage.getReferencedFilePaths() != null) {
return prevMessage.getReferencedFilePaths().stream();
}
return Stream.empty();
})
.map(filePath -> {
try {
return new ReferencedFile(new File(filePath));
} catch (Exception e) {
return null;
}
})
.filter(Objects::nonNull)
.toList();
}
return referencedFiles;
}
public void sendMessage(Message message, ConversationType conversationType) {
@ -134,79 +151,65 @@ public class ChatToolWindowTabPanel implements Disposable {
ConversationType conversationType,
@Nullable String highlightedText) {
ApplicationManager.getApplication().invokeLater(() -> {
var referencedFiles = project.getUserData(CodeGPTKeys.SELECTED_FILES);
var chatToolWindowPanel = project.getService(ChatToolWindowContentManager.class)
.tryFindChatToolWindowPanel();
if (referencedFiles != null && !referencedFiles.isEmpty()) {
var referencedFilePaths = referencedFiles.stream()
List<ReferencedFile> referencedFiles = getReferencedFiles();
if (!referencedFiles.isEmpty()) {
message.setReferencedFilePaths(referencedFiles.stream()
.map(ReferencedFile::getFilePath)
.toList();
message.setReferencedFilePaths(referencedFilePaths);
.toList());
message.setUserMessage(message.getPrompt());
chatToolWindowPanel.ifPresent(panel -> panel.clearNotifications(project));
} else {
referencedFiles = conversation.getMessages().stream()
.flatMap(prevMessage -> {
if (prevMessage.getReferencedFilePaths() != null) {
return prevMessage.getReferencedFilePaths().stream();
}
return Stream.empty();
})
.map(filePath -> {
try {
return new ReferencedFile(new File(filePath));
} catch (Exception e) {
return null;
}
})
.filter(Objects::nonNull)
.toList();
}
String attachedImagePath = CodeGPTKeys.IMAGE_ATTACHMENT_FILE_PATH.get(project);
if (attachedImagePath != null) {
message.setImageFilePath(attachedImagePath);
}
totalTokensPanel.updateConversationTokens(conversation);
totalTokensPanel.updateReferencedFilesTokens(referencedFiles);
var userMessagePanel = new UserMessagePanel(project, message, this);
var attachedFilePath = CodeGPTKeys.IMAGE_ATTACHMENT_FILE_PATH.get(project);
var callParameters =
getCallParameters(conversationType, message, highlightedText, attachedFilePath);
callParameters.setReferencedFiles(referencedFiles);
if (callParameters.getImageData() != null) {
message.setImageFilePath(attachedFilePath);
chatToolWindowPanel.ifPresent(panel -> panel.clearNotifications(project));
userMessagePanel.displayImage(attachedFilePath);
if (attachedImagePath != null || !referencedFiles.isEmpty()) {
project.getService(ChatToolWindowContentManager.class)
.tryFindChatToolWindowPanel()
.ifPresent(panel -> panel.clearNotifications(project));
}
var callParameters = getCallParameters(
message,
conversationType,
referencedFiles,
highlightedText,
attachedImagePath);
var responsePanel = createResponsePanel(callParameters);
var messagePanel = toolWindowScrollablePanel.addMessage(message.getId());
messagePanel.add(userMessagePanel);
var responsePanel = createResponsePanel(callParameters, conversationType);
messagePanel.add(new UserMessagePanel(project, message, this));
messagePanel.add(responsePanel);
call(callParameters, responsePanel);
});
}
private CallParameters getCallParameters(
ConversationType conversationType,
private ChatCompletionParameters getCallParameters(
Message message,
ConversationType conversationType,
List<ReferencedFile> referencedFiles,
@Nullable String highlightedText,
@Nullable String attachedFilePath) {
var callParameters = new CallParameters(
chatSession.getId(),
conversation,
conversationType,
message,
highlightedText,
false);
if (attachedFilePath != null && !attachedFilePath.isEmpty()) {
@Nullable String attachedImagePath) {
var builder = ChatCompletionParameters.builder(conversation, message)
.sessionId(chatSession.getId())
.conversationType(conversationType)
.highlightedText(highlightedText)
.referencedFiles(referencedFiles);
if (attachedImagePath != null && !attachedImagePath.isEmpty()) {
try {
callParameters.setImageData(Files.readAllBytes(Path.of(attachedFilePath)));
callParameters.setImageMediaType(FileUtil.getImageMediaType(attachedFilePath));
builder
.imageData(Files.readAllBytes(Path.of(attachedImagePath)))
.imageMediaType(FileUtil.getImageMediaType(attachedImagePath));
} catch (IOException e) {
throw new RuntimeException(e);
}
}
return callParameters;
return builder.build();
}
private boolean hasReferencedFilePaths(Message message) {
@ -219,15 +222,13 @@ public class ChatToolWindowTabPanel implements Disposable {
it -> it.getReferencedFilePaths() != null && !it.getReferencedFilePaths().isEmpty());
}
private ResponsePanel createResponsePanel(
CallParameters callParameters,
ConversationType conversationType) {
private ResponsePanel createResponsePanel(ChatCompletionParameters callParameters) {
var message = callParameters.getMessage();
var fileContextIncluded =
hasReferencedFilePaths(message) || hasReferencedFilePaths(conversation);
return new ResponsePanel()
.withReloadAction(() -> reloadMessage(message, conversation, conversationType))
.withReloadAction(() -> reloadMessage(callParameters))
.withDeleteAction(() -> removeMessage(message.getId(), conversation))
.addContent(
new ChatMessageResponseBody(
@ -241,31 +242,22 @@ public class ChatToolWindowTabPanel implements Disposable {
this));
}
private void reloadMessage(
Message message,
Conversation conversation,
ConversationType conversationType) {
private void reloadMessage(ChatCompletionParameters prevParameters) {
var prevMessage = prevParameters.getMessage();
ResponsePanel responsePanel = null;
try {
responsePanel = toolWindowScrollablePanel.getMessageResponsePanel(message.getId());
responsePanel = toolWindowScrollablePanel.getMessageResponsePanel(prevMessage.getId());
((ChatMessageResponseBody) responsePanel.getContent()).clear();
toolWindowScrollablePanel.update();
} catch (Exception e) {
throw new RuntimeException("Could not delete the existing message component", e);
} finally {
LOG.debug("Reloading message: " + message.getId());
LOG.debug("Reloading message: " + prevMessage.getId());
if (responsePanel != null) {
message.setResponse("");
conversationService.saveMessage(conversation, message);
call(new CallParameters(
chatSession.getId(),
conversation,
conversationType,
message,
null,
true),
responsePanel);
prevMessage.setResponse("");
conversationService.saveMessage(conversation, prevMessage);
call(prevParameters.toBuilder().retry(true).build(), responsePanel);
}
totalTokensPanel.updateConversationTokens(conversation);
@ -292,7 +284,7 @@ public class ChatToolWindowTabPanel implements Disposable {
totalTokensPanel.updateConversationTokens(conversation);
}
private void call(CallParameters callParameters, ResponsePanel responsePanel) {
private void call(ChatCompletionParameters callParameters, ResponsePanel responsePanel) {
var responseContainer = (ChatMessageResponseBody) responsePanel.getContent();
if (!CompletionRequestService.getInstance().isAllowed()) {
@ -316,25 +308,6 @@ public class ChatToolWindowTabPanel implements Disposable {
requestHandler.call(callParameters);
}
private String processEditorSelection(Editor editor, Message message) {
if (editor == null) {
return null;
}
SelectionModel selectionModel = editor.getSelectionModel();
String selectedText = selectionModel.getSelectedText();
if (selectedText == null || selectedText.isEmpty()) {
return null;
}
String fileExtension = FileUtil.getFileExtension(
((EditorEx) editor).getVirtualFile().getName());
message.setPrompt(
message.getPrompt() + String.format("%n```%s%n%s%n```", fileExtension, selectedText));
selectionModel.removeSelection();
return selectedText;
}
private Unit handleSubmit(String text, List<? extends AppliedActionInlay> appliedInlayActions) {
var message = new Message(text);
var editor = EditorUtil.getSelectedEditor(project);
@ -430,7 +403,10 @@ public class ChatToolWindowTabPanel implements Disposable {
var messagePanel = toolWindowScrollablePanel.addMessage(message.getId());
messagePanel.add(userMessagePanel);
messagePanel.add(new ResponsePanel()
.withReloadAction(() -> reloadMessage(message, conversation, ConversationType.DEFAULT))
.withReloadAction(() -> reloadMessage(
ChatCompletionParameters.builder(conversation, message)
.conversationType(ConversationType.DEFAULT)
.build()))
.withDeleteAction(() -> removeMessage(message.getId(), conversation))
.addContent(messageResponseBody));
});

View file

@ -5,7 +5,7 @@ import static com.intellij.openapi.ui.Messages.OK;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.diagnostic.Logger;
import ee.carlrobert.codegpt.EncodingManager;
import ee.carlrobert.codegpt.completions.CallParameters;
import ee.carlrobert.codegpt.completions.ChatCompletionParameters;
import ee.carlrobert.codegpt.completions.CompletionResponseEventListener;
import ee.carlrobert.codegpt.conversations.Conversation;
import ee.carlrobert.codegpt.conversations.ConversationService;
@ -53,16 +53,10 @@ abstract class ToolWindowCompletionResponseEventListener implements
@Override
public void handleMessage(String partialMessage) {
try {
responseContainer.update(partialMessage);
messageBuilder.append(partialMessage);
if (!completed) {
var ongoingTokens = encodingManager.countTokens(messageBuilder.toString());
ApplicationManager.getApplication().invokeLater(() -> {
totalTokensPanel.update(
totalTokensPanel.getTokenDetails().getTotal() + ongoingTokens);
});
}
var ongoingTokens = encodingManager.countTokens(messageBuilder.toString());
responseContainer.updateMessage(partialMessage);
totalTokensPanel.update(totalTokensPanel.getTokenDetails().getTotal() + ongoingTokens);
} catch (Exception e) {
responseContainer.displayError("Something went wrong.");
throw new RuntimeException("Error while updating the content", e);
@ -105,7 +99,7 @@ abstract class ToolWindowCompletionResponseEventListener implements
}
@Override
public void handleCompleted(String fullMessage, CallParameters callParameters) {
public void handleCompleted(String fullMessage, ChatCompletionParameters callParameters) {
conversationService.saveMessage(fullMessage, callParameters);
ApplicationManager.getApplication().invokeLater(() -> {

View file

@ -8,6 +8,7 @@ import static javax.swing.event.HyperlinkEvent.EventType.ACTIVATED;
import com.intellij.icons.AllIcons.General;
import com.intellij.openapi.Disposable;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.diagnostic.Logger;
import com.intellij.openapi.fileEditor.FileEditorManager;
import com.intellij.openapi.options.ShowSettingsUtil;
import com.intellij.openapi.project.Project;
@ -53,6 +54,8 @@ import org.jetbrains.annotations.Nullable;
public class ChatMessageResponseBody extends JPanel {
private static final Logger LOG = Logger.getInstance(ChatMessageResponseBody.class);
private final Project project;
private final Disposable parentDisposable;
private final StreamParser streamParser;
@ -123,14 +126,17 @@ public class ChatMessageResponseBody extends JPanel {
}
public ChatMessageResponseBody withResponse(String response) {
for (var message : MarkdownUtil.splitCodeBlocks(response)) {
processResponse(message, message.startsWith("```"), false);
try {
for (var message : MarkdownUtil.splitCodeBlocks(response)) {
processResponse(message, message.startsWith("```"), false);
}
} catch (Exception e) {
LOG.error("Something went wrong while processing input", e);
}
return this;
}
public void update(String partialMessage) {
public void updateMessage(String partialMessage) {
for (var item : streamParser.parse(partialMessage)) {
processResponse(item.response(), CODE.equals(item.type()), true);
}
@ -261,22 +267,24 @@ public class ChatMessageResponseBody extends JPanel {
var codeBlock = ((FencedCodeBlock) child);
var code = codeBlock.getContentChars().unescape();
if (!code.isEmpty()) {
if (currentlyProcessedEditorPanel == null) {
ApplicationManager.getApplication().invokeAndWait(() -> {
ApplicationManager.getApplication().invokeLater(() -> {
if (currentlyProcessedEditorPanel == null) {
prepareProcessingCode(code, codeBlock.getInfo().unescape());
});
}
EditorUtil.updateEditorDocument(currentlyProcessedEditorPanel.getEditor(), code);
}
EditorUtil.updateEditorDocument(currentlyProcessedEditorPanel.getEditor(), code);
});
}
}
}
private void processText(String markdownText, boolean caretVisible) {
var html = convertMdToHtml(markdownText);
if (currentlyProcessedTextPane == null) {
prepareProcessingText(caretVisible);
}
currentlyProcessedTextPane.setText(html);
ApplicationManager.getApplication().invokeLater(() -> {
if (currentlyProcessedTextPane == null) {
prepareProcessingText(caretVisible);
}
currentlyProcessedTextPane.setText(html);
});
}
private void prepareProcessingText(boolean caretVisible) {

View file

@ -43,6 +43,10 @@ public class UserMessagePanel extends JPanel {
add(additionalContextPanel, BorderLayout.CENTER);
}
if (message.getImageFilePath() != null && !message.getImageFilePath().isEmpty()) {
displayImage(message.getImageFilePath());
}
var referencedFilePaths = message.getReferencedFilePaths();
if (referencedFilePaths != null && !referencedFilePaths.isEmpty()) {
add(createResponseBody(

View file

@ -10,7 +10,7 @@ import com.intellij.util.ui.AsyncProcessIcon
import com.intellij.openapi.util.text.StringUtil
import com.jetbrains.rd.util.AtomicReference
import ee.carlrobert.codegpt.completions.CompletionRequestService
import ee.carlrobert.codegpt.completions.EditCodeRequestParameters
import ee.carlrobert.codegpt.completions.EditCodeCompletionParameters
import ee.carlrobert.codegpt.ui.ObservableProperties
import javax.swing.JButton
@ -43,7 +43,7 @@ class EditCodeSubmissionHandler(
runInEdt { editor.selectionModel.removeSelection() }
service<CompletionRequestService>().getEditCodeCompletionAsync(
EditCodeRequestParameters(userPrompt, selectedText),
EditCodeCompletionParameters(userPrompt, selectedText),
EditCodeCompletionListener(
editor,
selectionTextRange,

View file

@ -1,19 +0,0 @@
package ee.carlrobert.codegpt.completions
interface CompletionCallParameters
data class ChatCompletionRequestParameters(
val callParameters: CallParameters
) : CompletionCallParameters
data class CommitMessageRequestParameters(
val gitDiff: String,
val systemPrompt: String
) : CompletionCallParameters
data class LookupRequestCallParameters(val prompt: String) : CompletionCallParameters
data class EditCodeRequestParameters(
val prompt: String,
val selectedText: String
) : CompletionCallParameters

View file

@ -0,0 +1,87 @@
package ee.carlrobert.codegpt.completions
import ee.carlrobert.codegpt.ReferencedFile
import ee.carlrobert.codegpt.conversations.Conversation
import ee.carlrobert.codegpt.conversations.message.Message
import java.util.*
interface CompletionParameters
class ChatCompletionParameters private constructor(
val conversation: Conversation,
val conversationType: ConversationType,
val message: Message,
var sessionId: UUID?,
var highlightedText: String?,
var retry: Boolean,
var imageMediaType: String?,
var imageData: ByteArray?,
var referencedFiles: List<ReferencedFile>?
) : CompletionParameters {
fun toBuilder(): Builder {
return Builder(conversation, message).apply {
sessionId(this@ChatCompletionParameters.sessionId)
conversationType(this@ChatCompletionParameters.conversationType)
highlightedText(this@ChatCompletionParameters.highlightedText)
retry(this@ChatCompletionParameters.retry)
imageMediaType(this@ChatCompletionParameters.imageMediaType)
imageData(this@ChatCompletionParameters.imageData)
referencedFiles(this@ChatCompletionParameters.referencedFiles)
}
}
class Builder(private val conversation: Conversation, private val message: Message) {
private var sessionId: UUID? = null
private var conversationType: ConversationType = ConversationType.DEFAULT
private var highlightedText: String? = null
private var retry: Boolean = false
private var imageMediaType: String? = null
private var imageData: ByteArray? = null
private var referencedFiles: List<ReferencedFile>? = null
fun sessionId(sessionId: UUID?) = apply { this.sessionId = sessionId }
fun conversationType(conversationType: ConversationType) =
apply { this.conversationType = conversationType }
fun highlightedText(highlightedText: String?) =
apply { this.highlightedText = highlightedText }
fun retry(retry: Boolean) = apply { this.retry = retry }
fun imageMediaType(imageMediaType: String?) = apply { this.imageMediaType = imageMediaType }
fun imageData(imageData: ByteArray?) = apply { this.imageData = imageData }
fun referencedFiles(referencedFiles: List<ReferencedFile>?) =
apply { this.referencedFiles = referencedFiles }
fun build(): ChatCompletionParameters {
return ChatCompletionParameters(
conversation,
conversationType,
message,
sessionId,
highlightedText,
retry,
imageMediaType,
imageData,
referencedFiles
)
}
}
companion object {
@JvmStatic
fun builder(conversation: Conversation, message: Message) = Builder(conversation, message)
}
}
data class CommitMessageCompletionParameters(
val gitDiff: String,
val systemPrompt: String
) : CompletionParameters
data class LookupCompletionParameters(val prompt: String) : CompletionParameters
data class EditCodeCompletionParameters(
val prompt: String,
val selectedText: String
) : CompletionParameters

View file

@ -7,10 +7,10 @@ import ee.carlrobert.codegpt.settings.service.ServiceType
import ee.carlrobert.llm.completion.CompletionRequest
interface CompletionRequestFactory {
fun createChatRequest(params: ChatCompletionRequestParameters): CompletionRequest
fun createEditCodeRequest(params: EditCodeRequestParameters): CompletionRequest
fun createCommitMessageRequest(params: CommitMessageRequestParameters): CompletionRequest
fun createLookupRequest(params: LookupRequestCallParameters): CompletionRequest
fun createChatRequest(params: ChatCompletionParameters): CompletionRequest
fun createEditCodeRequest(params: EditCodeCompletionParameters): CompletionRequest
fun createCommitMessageRequest(params: CommitMessageCompletionParameters): CompletionRequest
fun createLookupRequest(params: LookupCompletionParameters): CompletionRequest
companion object {
@JvmStatic
@ -30,16 +30,16 @@ interface CompletionRequestFactory {
}
abstract class BaseRequestFactory : CompletionRequestFactory {
override fun createEditCodeRequest(params: EditCodeRequestParameters): CompletionRequest {
override fun createEditCodeRequest(params: EditCodeCompletionParameters): CompletionRequest {
val prompt = "Code to modify:\n${params.selectedText}\n\nInstructions: ${params.prompt}"
return createBasicCompletionRequest(EDIT_CODE_SYSTEM_PROMPT, prompt, 8192, true)
}
override fun createCommitMessageRequest(params: CommitMessageRequestParameters): CompletionRequest {
override fun createCommitMessageRequest(params: CommitMessageCompletionParameters): CompletionRequest {
return createBasicCompletionRequest(params.systemPrompt, params.gitDiff, 512, true)
}
override fun createLookupRequest(params: LookupRequestCallParameters): CompletionRequest {
override fun createLookupRequest(params: LookupCompletionParameters): CompletionRequest {
return createBasicCompletionRequest(GENERATE_METHOD_NAMES_SYSTEM_PROMPT, params.prompt, 512)
}
@ -50,7 +50,7 @@ abstract class BaseRequestFactory : CompletionRequestFactory {
stream: Boolean = false
): CompletionRequest
protected fun getPromptWithFilesContext(callParameters: CallParameters): String {
protected fun getPromptWithFilesContext(callParameters: ChatCompletionParameters): String {
return callParameters.referencedFiles?.let {
if (it.isEmpty()) {
callParameters.message.prompt

View file

@ -2,7 +2,7 @@ package ee.carlrobert.codegpt.completions.factory
import com.intellij.openapi.components.service
import ee.carlrobert.codegpt.completions.BaseRequestFactory
import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters
import ee.carlrobert.codegpt.completions.ChatCompletionParameters
import ee.carlrobert.codegpt.completions.factory.OpenAIRequestFactory.Companion.buildOpenAIMessages
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest
@ -10,12 +10,11 @@ import ee.carlrobert.llm.completion.CompletionRequest
class AzureRequestFactory : BaseRequestFactory() {
override fun createChatRequest(params: ChatCompletionRequestParameters): OpenAIChatCompletionRequest {
override fun createChatRequest(params: ChatCompletionParameters): OpenAIChatCompletionRequest {
val configuration = service<ConfigurationSettings>().state
val (callParameters) = params
val requestBuilder: OpenAIChatCompletionRequest.Builder =
OpenAIChatCompletionRequest.Builder(
buildOpenAIMessages(null, callParameters, callParameters.referencedFiles)
buildOpenAIMessages(null, params, params.referencedFiles)
)
.setMaxTokens(configuration.maxTokens)
.setStream(true)

View file

@ -2,7 +2,7 @@ package ee.carlrobert.codegpt.completions.factory
import com.intellij.openapi.components.service
import ee.carlrobert.codegpt.completions.BaseRequestFactory
import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters
import ee.carlrobert.codegpt.completions.ChatCompletionParameters
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
import ee.carlrobert.codegpt.settings.persona.PersonaSettings
import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings
@ -11,15 +11,14 @@ import ee.carlrobert.llm.completion.CompletionRequest
class ClaudeRequestFactory : BaseRequestFactory() {
override fun createChatRequest(params: ChatCompletionRequestParameters): ClaudeCompletionRequest {
val (callParameters) = params
override fun createChatRequest(params: ChatCompletionParameters): ClaudeCompletionRequest {
return ClaudeCompletionRequest().apply {
model = service<AnthropicSettings>().state.model
maxTokens = service<ConfigurationSettings>().state.maxTokens
isStream = true
system = PersonaSettings.getSystemPrompt()
messages = callParameters.conversation.messages
messages = params.conversation.messages
.filter { it.response != null && it.response.isNotEmpty() }
.flatMap { prevMessage ->
sequenceOf(
@ -29,18 +28,15 @@ class ClaudeRequestFactory : BaseRequestFactory() {
}
when {
callParameters.imageMediaType != null && callParameters.imageData.isNotEmpty() -> {
params.imageMediaType != null && params.imageData != null -> {
messages.add(
ClaudeCompletionDetailedMessage(
"user",
listOf(
ClaudeMessageImageContent(
ClaudeBase64Source(
callParameters.imageMediaType,
callParameters.imageData
)
ClaudeBase64Source(params.imageMediaType, params.imageData)
),
ClaudeMessageTextContent(callParameters.message.prompt)
ClaudeMessageTextContent(params.message.prompt)
)
)
)
@ -49,8 +45,7 @@ class ClaudeRequestFactory : BaseRequestFactory() {
else -> {
messages.add(
ClaudeCompletionStandardMessage(
"user",
getPromptWithFilesContext(callParameters)
"user", getPromptWithFilesContext(params)
)
)
}

View file

@ -4,7 +4,7 @@ import com.intellij.openapi.application.ApplicationInfo
import com.intellij.openapi.components.service
import ee.carlrobert.codegpt.CodeGPTPlugin
import ee.carlrobert.codegpt.completions.BaseRequestFactory
import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters
import ee.carlrobert.codegpt.completions.ChatCompletionParameters
import ee.carlrobert.codegpt.completions.factory.OpenAIRequestFactory.Companion.buildOpenAIMessages
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTServiceSettings
@ -13,14 +13,13 @@ import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionSt
class CodeGPTRequestFactory : BaseRequestFactory() {
override fun createChatRequest(params: ChatCompletionRequestParameters): ChatCompletionRequest {
val (callParameters) = params
override fun createChatRequest(params: ChatCompletionParameters): ChatCompletionRequest {
val model = service<CodeGPTServiceSettings>().state.chatCompletionSettings.model
val configuration = service<ConfigurationSettings>().state
val requestBuilder: ChatCompletionRequest.Builder =
ChatCompletionRequest.Builder(buildOpenAIMessages(model, callParameters))
ChatCompletionRequest.Builder(buildOpenAIMessages(model, params))
.setModel(model)
.setSessionId(callParameters.sessionId)
.setSessionId(params.sessionId)
.setMetadata(
Metadata(
CodeGPTPlugin.getVersion(),
@ -40,16 +39,16 @@ class CodeGPTRequestFactory : BaseRequestFactory() {
.setTemperature(configuration.temperature.toDouble())
}
if (callParameters.message.isWebSearchIncluded) {
if (params.message.isWebSearchIncluded) {
requestBuilder.setWebSearchIncluded(true)
}
val documentationDetails = callParameters.message.documentationDetails
val documentationDetails = params.message.documentationDetails
if (documentationDetails != null) {
requestBuilder.setDocumentationDetails(
DocumentationDetails(documentationDetails.name, documentationDetails.url)
)
}
callParameters.referencedFiles?.let {
params.referencedFiles?.let {
val fileContexts = it.map { file ->
ContextFile(file.fileName, file.fileContent)
}
@ -81,7 +80,7 @@ class CodeGPTRequestFactory : BaseRequestFactory() {
.build()
}
fun buildBasicO1Request(
private fun buildBasicO1Request(
model: String,
prompt: String,
systemPrompt: String = "",

View file

@ -3,7 +3,7 @@ package ee.carlrobert.codegpt.completions.factory
import com.fasterxml.jackson.databind.ObjectMapper
import com.intellij.openapi.components.service
import ee.carlrobert.codegpt.completions.BaseRequestFactory
import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters
import ee.carlrobert.codegpt.completions.ChatCompletionParameters
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey
import ee.carlrobert.codegpt.credentials.CredentialsStore.getCredential
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceChatCompletionSettingsState
@ -19,17 +19,12 @@ class CustomOpenAIRequest(val request: Request) : CompletionRequest
class CustomOpenAIRequestFactory : BaseRequestFactory() {
override fun createChatRequest(params: ChatCompletionRequestParameters): CustomOpenAIRequest {
val (callParameters) = params
override fun createChatRequest(params: ChatCompletionParameters): CustomOpenAIRequest {
val request = buildCustomOpenAIChatCompletionRequest(
service<CustomServiceSettings>()
.state
.chatCompletionSettings,
OpenAIRequestFactory.buildOpenAIMessages(
null,
callParameters,
callParameters.referencedFiles
),
OpenAIRequestFactory.buildOpenAIMessages(null, params, params.referencedFiles),
true,
getCredential(CredentialKey.CUSTOM_SERVICE_API_KEY)
)

View file

@ -20,10 +20,9 @@ import java.nio.file.Path
class GoogleRequestFactory : BaseRequestFactory() {
override fun createChatRequest(params: ChatCompletionRequestParameters): GoogleCompletionRequest {
val (callParameters) = params
override fun createChatRequest(params: ChatCompletionParameters): GoogleCompletionRequest {
val configuration = service<ConfigurationSettings>().state
val messages = buildGoogleMessages(service<GoogleSettings>().state.model, callParameters)
val messages = buildGoogleMessages(service<GoogleSettings>().state.model, params)
return GoogleCompletionRequest.Builder(messages)
.generationConfig(
GoogleGenerationConfig.Builder()
@ -57,9 +56,9 @@ class GoogleRequestFactory : BaseRequestFactory() {
private fun buildGoogleMessages(
model: String?,
callParameters: CallParameters
params: ChatCompletionParameters
): List<GoogleCompletionContent> {
val messages = buildGoogleMessages(callParameters)
val messages = buildGoogleMessages(params)
if (model == null) {
return messages
@ -81,7 +80,7 @@ class GoogleRequestFactory : BaseRequestFactory() {
} else {
tryReducingGoogleMessagesOrThrow(
messages,
callParameters.conversation.isDiscardTokenLimit,
params.conversation.isDiscardTokenLimit,
totalUsage,
googleModel.maxTokens
)
@ -89,11 +88,11 @@ class GoogleRequestFactory : BaseRequestFactory() {
} ?: messages
}
private fun buildGoogleMessages(callParameters: CallParameters): List<GoogleCompletionContent> {
val message = callParameters.message
private fun buildGoogleMessages(params: ChatCompletionParameters): List<GoogleCompletionContent> {
val message = params.message
val messages = mutableListOf<GoogleCompletionContent>()
when (callParameters.conversationType) {
when (params.conversationType) {
ConversationType.DEFAULT -> {
messages.add(
GoogleCompletionContent(
@ -114,8 +113,8 @@ class GoogleRequestFactory : BaseRequestFactory() {
else -> {}
}
for (prevMessage in callParameters.conversation.messages) {
if (callParameters.isRetry && prevMessage.id == message.id) {
for (prevMessage in params.conversation.messages) {
if (params.retry && prevMessage.id == message.id) {
break
}
@ -143,15 +142,15 @@ class GoogleRequestFactory : BaseRequestFactory() {
messages.add(GoogleCompletionContent("model", listOf(prevMessage.response)))
}
if (callParameters.imageMediaType != null && callParameters.imageData.isNotEmpty()) {
if (params.imageMediaType != null && params.imageData != null) {
messages.add(
GoogleCompletionContent(
listOf(
GoogleContentPart(
null,
GoogleContentPart.Blob(
callParameters.imageMediaType,
callParameters.imageData
params.imageMediaType,
params.imageData
)
),
GoogleContentPart(message.prompt)
@ -162,7 +161,7 @@ class GoogleRequestFactory : BaseRequestFactory() {
messages.add(
GoogleCompletionContent(
"user",
listOf(getPromptWithFilesContext(callParameters))
listOf(getPromptWithFilesContext(params))
)
)
}

View file

@ -2,7 +2,7 @@ package ee.carlrobert.codegpt.completions.factory
import com.intellij.openapi.components.service
import ee.carlrobert.codegpt.completions.BaseRequestFactory
import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters
import ee.carlrobert.codegpt.completions.ChatCompletionParameters
import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT
import ee.carlrobert.codegpt.completions.ConversationType
import ee.carlrobert.codegpt.completions.llama.LlamaModel
@ -14,18 +14,17 @@ import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest
class LlamaRequestFactory : BaseRequestFactory() {
override fun createChatRequest(params: ChatCompletionRequestParameters): LlamaCompletionRequest {
val (callParameters) = params
override fun createChatRequest(params: ChatCompletionParameters): LlamaCompletionRequest {
val promptTemplate = getPromptTemplate()
val systemPrompt =
if (callParameters.conversationType == ConversationType.FIX_COMPILE_ERRORS)
if (params.conversationType == ConversationType.FIX_COMPILE_ERRORS)
FIX_COMPILE_ERRORS_SYSTEM_PROMPT
else
getSystemPrompt()
val prompt = promptTemplate.buildPrompt(
systemPrompt,
getPromptWithFilesContext(callParameters),
callParameters.conversation.messages
getPromptWithFilesContext(params),
params.conversation.messages
)
return buildLlamaRequest(prompt, promptTemplate.stopTokens, true)

View file

@ -2,8 +2,7 @@ package ee.carlrobert.codegpt.completions.factory
import com.intellij.openapi.components.service
import ee.carlrobert.codegpt.completions.BaseRequestFactory
import ee.carlrobert.codegpt.completions.CallParameters
import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters
import ee.carlrobert.codegpt.completions.ChatCompletionParameters
import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT
import ee.carlrobert.codegpt.completions.ConversationType
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
@ -19,13 +18,12 @@ import java.util.*
class OllamaRequestFactory : BaseRequestFactory() {
override fun createChatRequest(params: ChatCompletionRequestParameters): OllamaChatCompletionRequest {
val (callParameters) = params
override fun createChatRequest(params: ChatCompletionParameters): OllamaChatCompletionRequest {
val configuration = service<ConfigurationSettings>().state
val settings = service<OllamaSettings>().state
return OllamaChatCompletionRequest.Builder(
settings.model,
buildOllamaMessages(callParameters)
buildOllamaMessages(params)
)
.setStream(true)
.setOptions(
@ -54,11 +52,11 @@ class OllamaRequestFactory : BaseRequestFactory() {
.build()
}
private fun buildOllamaMessages(callParameters: CallParameters): List<OllamaChatCompletionMessage> {
val message = callParameters.message
private fun buildOllamaMessages(params: ChatCompletionParameters): List<OllamaChatCompletionMessage> {
val message = params.message
val messages = mutableListOf<OllamaChatCompletionMessage>()
when (callParameters.conversationType) {
when (params.conversationType) {
ConversationType.DEFAULT -> messages.add(
OllamaChatCompletionMessage("system", PersonaSettings.getSystemPrompt(), null)
)
@ -70,8 +68,8 @@ class OllamaRequestFactory : BaseRequestFactory() {
else -> {}
}
for (prevMessage in callParameters.conversation.messages) {
if (callParameters.isRetry && prevMessage.id == message.id) break
for (prevMessage in params.conversation.messages) {
if (params.retry && prevMessage.id == message.id) break
prevMessage.imageFilePath?.takeIf { it.isNotEmpty() }?.let { imagePath ->
try {
@ -91,7 +89,7 @@ class OllamaRequestFactory : BaseRequestFactory() {
messages.add(
OllamaChatCompletionMessage(
"user",
getPromptWithFilesContext(callParameters),
getPromptWithFilesContext(params),
null
)
)
@ -100,8 +98,8 @@ class OllamaRequestFactory : BaseRequestFactory() {
messages.add(OllamaChatCompletionMessage("assistant", prevMessage.response, null))
}
if (callParameters.imageMediaType != null && callParameters.imageData.isNotEmpty()) {
val imageBase64 = Base64.getEncoder().encodeToString(callParameters.imageData)
if (params.imageMediaType != null && params.imageData != null) {
val imageBase64 = Base64.getEncoder().encodeToString(params.imageData)
messages.add(OllamaChatCompletionMessage("user", message.prompt, listOf(imageBase64)))
} else {
messages.add(OllamaChatCompletionMessage("user", message.prompt, null))

View file

@ -21,13 +21,12 @@ import java.nio.file.Path
class OpenAIRequestFactory : CompletionRequestFactory {
override fun createChatRequest(params: ChatCompletionRequestParameters): OpenAIChatCompletionRequest {
val (callParameters) = params
override fun createChatRequest(params: ChatCompletionParameters): OpenAIChatCompletionRequest {
val model = service<OpenAISettings>().state.model
val configuration = service<ConfigurationSettings>().state
val requestBuilder: OpenAIChatCompletionRequest.Builder =
OpenAIChatCompletionRequest.Builder(
buildOpenAIMessages(model, callParameters, callParameters.referencedFiles)
buildOpenAIMessages(model, params, params.referencedFiles)
)
.setModel(model)
if ("o1-mini" == model || "o1-preview" == model) {
@ -48,7 +47,7 @@ class OpenAIRequestFactory : CompletionRequestFactory {
return requestBuilder.build()
}
override fun createEditCodeRequest(params: EditCodeRequestParameters): OpenAIChatCompletionRequest {
override fun createEditCodeRequest(params: EditCodeCompletionParameters): OpenAIChatCompletionRequest {
val model = service<OpenAISettings>().state.model
val prompt = "Code to modify:\n${params.selectedText}\n\nInstructions: ${params.prompt}"
if (model == "o1-mini" || model == "o1-preview") {
@ -57,7 +56,7 @@ class OpenAIRequestFactory : CompletionRequestFactory {
return createBasicCompletionRequest(EDIT_CODE_SYSTEM_PROMPT, prompt, model, true)
}
override fun createCommitMessageRequest(params: CommitMessageRequestParameters): OpenAIChatCompletionRequest {
override fun createCommitMessageRequest(params: CommitMessageCompletionParameters): OpenAIChatCompletionRequest {
val model = service<OpenAISettings>().state.model
val (gitDiff, systemPrompt) = params
if (model == "o1-mini" || model == "o1-preview") {
@ -66,7 +65,7 @@ class OpenAIRequestFactory : CompletionRequestFactory {
return createBasicCompletionRequest(systemPrompt, gitDiff, model, true)
}
override fun createLookupRequest(params: LookupRequestCallParameters): OpenAIChatCompletionRequest {
override fun createLookupRequest(params: LookupCompletionParameters): OpenAIChatCompletionRequest {
val model = service<OpenAISettings>().state.model
val (prompt) = params
if (model == "o1-mini" || model == "o1-preview") {
@ -103,7 +102,7 @@ class OpenAIRequestFactory : CompletionRequestFactory {
fun buildOpenAIMessages(
model: String?,
callParameters: CallParameters,
callParameters: ChatCompletionParameters,
referencedFiles: List<ReferencedFile>? = mutableListOf()
): List<OpenAIChatCompletionMessage> {
val messages = buildOpenAIChatMessages(model, callParameters, referencedFiles)
@ -140,7 +139,7 @@ class OpenAIRequestFactory : CompletionRequestFactory {
private fun buildOpenAIChatMessages(
model: String?,
callParameters: CallParameters,
callParameters: ChatCompletionParameters,
referencedFiles: List<ReferencedFile>? = mutableListOf()
): MutableList<OpenAIChatCompletionMessage> {
val message = callParameters.message
@ -169,7 +168,7 @@ class OpenAIRequestFactory : CompletionRequestFactory {
}
for (prevMessage in callParameters.conversation.messages) {
if (callParameters.isRetry && prevMessage.id == message.id) {
if (callParameters.retry && prevMessage.id == message.id) {
break
}
val prevMessageImageFilePath = prevMessage.imageFilePath
@ -203,7 +202,7 @@ class OpenAIRequestFactory : CompletionRequestFactory {
)
}
if (callParameters.imageMediaType != null && callParameters.imageData.isNotEmpty()) {
if (callParameters.imageMediaType != null && callParameters.imageData != null) {
messages.add(
OpenAIChatCompletionDetailedMessage(
"user",

View file

@ -22,12 +22,11 @@ class CompletionRequestProviderTest : IntegrationTest() {
val secondMessage = createDummyMessage(250)
conversation.addMessage(firstMessage)
conversation.addMessage(secondMessage)
val callParameters = ChatCompletionParameters
.builder(conversation, Message("TEST_CHAT_COMPLETION_PROMPT"))
.build()
val request = OpenAIRequestFactory().createChatRequest(
ChatCompletionRequestParameters(
CallParameters(conversation, Message("TEST_CHAT_COMPLETION_PROMPT"))
)
)
val request = OpenAIRequestFactory().createChatRequest(callParameters)
assertThat(request.messages)
.extracting("role", "content")
@ -49,12 +48,11 @@ class CompletionRequestProviderTest : IntegrationTest() {
val secondMessage = createDummyMessage(250)
conversation.addMessage(firstMessage)
conversation.addMessage(secondMessage)
val callParameters = ChatCompletionParameters
.builder(conversation, Message("TEST_CHAT_COMPLETION_PROMPT"))
.build()
val request = OpenAIRequestFactory().createChatRequest(
ChatCompletionRequestParameters(
CallParameters(conversation, Message("TEST_CHAT_COMPLETION_PROMPT"))
)
)
val request = OpenAIRequestFactory().createChatRequest(callParameters)
assertThat(request.messages)
.extracting("role", "content")
@ -76,19 +74,11 @@ class CompletionRequestProviderTest : IntegrationTest() {
val secondMessage = createDummyMessage("SECOND_TEST_PROMPT", 250)
conversation.addMessage(firstMessage)
conversation.addMessage(secondMessage)
val callParameters = ChatCompletionParameters.builder(conversation, secondMessage)
.retry(true)
.build()
val request = OpenAIRequestFactory().createChatRequest(
ChatCompletionRequestParameters(
CallParameters(
null,
conversation,
ConversationType.DEFAULT,
secondMessage,
null,
true
)
)
)
val request = OpenAIRequestFactory().createChatRequest(callParameters)
assertThat(request.messages)
.extracting("role", "content")
@ -111,12 +101,11 @@ class CompletionRequestProviderTest : IntegrationTest() {
val remainingMessage = createDummyMessage(1000)
conversation.addMessage(remainingMessage)
conversation.discardTokenLimits()
val callParameters = ChatCompletionParameters
.builder(conversation, Message("TEST_CHAT_COMPLETION_PROMPT"))
.build()
val request = OpenAIRequestFactory().createChatRequest(
ChatCompletionRequestParameters(
CallParameters(conversation, Message("TEST_CHAT_COMPLETION_PROMPT"))
)
)
val request = OpenAIRequestFactory().createChatRequest(callParameters)
assertThat(request.messages)
.extracting("role", "content")
@ -137,9 +126,9 @@ class CompletionRequestProviderTest : IntegrationTest() {
assertThrows(TotalUsageExceededException::class.java) {
OpenAIRequestFactory().createChatRequest(
ChatCompletionRequestParameters(
CallParameters(conversation, createDummyMessage(100))
)
ChatCompletionParameters
.builder(conversation, createDummyMessage(100))
.build()
)
}
}

View file

@ -16,246 +16,275 @@ import testsupport.IntegrationTest
class DefaultToolwindowChatCompletionRequestHandlerTest : IntegrationTest() {
fun testOpenAIChatCompletionCall() {
useOpenAIService()
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
val message = Message("TEST_PROMPT")
val conversation = ConversationService.getInstance().startConversation()
val requestHandler =
ToolwindowChatCompletionRequestHandler(
getRequestEventListener(message)
)
expectOpenAI(StreamHttpExchange { request: RequestEntity ->
assertThat(request.uri.path).isEqualTo("/v1/chat/completions")
assertThat(request.method).isEqualTo("POST")
assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY")
assertThat(request.body)
.extracting(
"model",
"messages")
.containsExactly(
"gpt-4",
listOf(
mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
mapOf("role" to "user", "content" to "TEST_PROMPT")))
listOf(
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))))
})
fun testOpenAIChatCompletionCall() {
useOpenAIService()
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
val message = Message("TEST_PROMPT")
val conversation = ConversationService.getInstance().startConversation()
expectOpenAI(StreamHttpExchange { request: RequestEntity ->
assertThat(request.uri.path).isEqualTo("/v1/chat/completions")
assertThat(request.method).isEqualTo("POST")
assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY")
assertThat(request.body)
.extracting(
"model",
"messages"
)
.containsExactly(
"gpt-4",
listOf(
mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
mapOf("role" to "user", "content" to "TEST_PROMPT")
)
)
listOf(
jsonMapResponse(
"choices",
jsonArray(jsonMap("delta", jsonMap("role", "assistant")))
),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))
)
})
val requestHandler =
ToolwindowChatCompletionRequestHandler(getRequestEventListener(message))
requestHandler.call(CallParameters(conversation, message))
requestHandler.call(ChatCompletionParameters.builder(conversation, message).build())
waitExpecting { "Hello!" == message.response }
}
fun testAzureChatCompletionCall() {
useAzureService()
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
val conversationService = ConversationService.getInstance()
val prevMessage = Message("TEST_PREV_PROMPT")
prevMessage.response = "TEST_PREV_RESPONSE"
val conversation = conversationService.startConversation()
conversation.addMessage(prevMessage)
conversationService.saveConversation(conversation)
expectAzure(StreamHttpExchange { request: RequestEntity ->
assertThat(request.uri.path).isEqualTo(
"/openai/deployments/TEST_DEPLOYMENT_ID/chat/completions")
assertThat(request.uri.query).isEqualTo("api-version=TEST_API_VERSION")
assertThat(request.headers["Api-key"]!![0]).isEqualTo("TEST_API_KEY")
assertThat(request.headers["X-llm-application-tag"]!![0]).isEqualTo("codegpt")
assertThat(request.body)
.extracting("messages")
.isEqualTo(
listOf(
mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
mapOf("role" to "user", "content" to "TEST_PREV_PROMPT"),
mapOf("role" to "assistant", "content" to "TEST_PREV_RESPONSE"),
mapOf("role" to "user", "content" to "TEST_PROMPT")))
listOf(
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))))
})
val message = Message("TEST_PROMPT")
val requestHandler =
ToolwindowChatCompletionRequestHandler(
getRequestEventListener(message)
)
requestHandler.call(CallParameters(conversation, message))
waitExpecting { "Hello!" == message.response }
}
fun testLlamaChatCompletionCall() {
useLlamaService()
service<ConfigurationSettings>().state.maxTokens = 99
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
val message = Message("TEST_PROMPT")
val conversation = ConversationService.getInstance().startConversation()
conversation.addMessage(Message("Ping", "Pong"))
val requestHandler =
ToolwindowChatCompletionRequestHandler(
getRequestEventListener(message)
)
expectLlama(StreamHttpExchange { request: RequestEntity ->
assertThat(request.uri.path).isEqualTo("/completion")
assertThat(request.body)
.extracting(
"prompt",
"n_predict",
"stream")
.containsExactly(
LLAMA.buildPrompt(
"TEST_SYSTEM_PROMPT",
"TEST_PROMPT",
conversation.messages),
99,
true)
listOf<String?>(
jsonMapResponse("content", "Hel"),
jsonMapResponse("content", "lo!"),
jsonMapResponse(
e("content", ""),
e("stop", true)))
})
requestHandler.call(CallParameters(conversation, message))
waitExpecting { "Hello!" == message.response }
}
fun testOllamaChatCompletionCall() {
useOllamaService()
service<ConfigurationSettings>().state.maxTokens = 99
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
val message = Message("TEST_PROMPT")
val conversation = ConversationService.getInstance().startConversation()
val requestHandler =
ToolwindowChatCompletionRequestHandler(
getRequestEventListener(message)
)
expectOllama(NdJsonStreamHttpExchange { request: RequestEntity ->
assertThat(request.uri.path).isEqualTo("/api/chat")
assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY")
assertThat(request.body)
.extracting(
"model",
"messages",
"options.num_predict",
"stream"
)
.containsExactly(
HuggingFaceModel.LLAMA_3_8B_Q6_K.code,
listOf(
mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
mapOf("role" to "user", "content" to "TEST_PROMPT")
),
99,
true
)
listOf(
jsonMapResponse(
e("message", jsonMap(e("content", "Hel"), e("role", "assistant"))),
e("done", false)),
jsonMapResponse(
e("message", jsonMap(e("content", "lo"), e("role", "assistant"))),
e("done", false)),
jsonMapResponse(
e("message", jsonMap(e("content", "!"), e("role", "assistant"))),
e("done", false)),
jsonMapResponse(
e("message", jsonMap(e("content", ""), e("role", "assistant"))),
e("done", true))
)
})
requestHandler.call(CallParameters(conversation, message))
waitExpecting { "Hello!" == message.response }
}
fun testGoogleChatCompletionCall() {
useGoogleService()
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
val message = Message("TEST_PROMPT")
val conversation = ConversationService.getInstance().startConversation()
val requestHandler =
ToolwindowChatCompletionRequestHandler(
getRequestEventListener(message)
)
expectGoogle(StreamHttpExchange { request: RequestEntity ->
assertThat(request.uri.path).isEqualTo("/v1/models/gemini-pro:streamGenerateContent")
assertThat(request.method).isEqualTo("POST")
assertThat(request.uri.query).isEqualTo("key=TEST_API_KEY&alt=sse")
assertThat(request.body)
.extracting("contents")
.isEqualTo(
listOf(
mapOf("parts" to listOf(mapOf("text" to "TEST_SYSTEM_PROMPT")), "role" to "user"),
mapOf("parts" to listOf(mapOf("text" to "Understood.")), "role" to "model"),
mapOf("parts" to listOf(mapOf("text" to "TEST_PROMPT")), "role" to "user"),
)
)
listOf(
jsonMapResponse(
"candidates",
jsonArray(jsonMap("content", jsonMap("parts", jsonArray(jsonMap("text", "Hello")))))
),
jsonMapResponse(
"candidates",
jsonArray(jsonMap("content", jsonMap("parts", jsonArray(jsonMap("text", "!")))))
)
)
})
requestHandler.call(CallParameters(conversation, message))
waitExpecting { "Hello!" == message.response }
}
fun testCodeGPTServiceChatCompletionCall() {
useCodeGPTService()
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
val message = Message("TEST_PROMPT")
val conversation = ConversationService.getInstance().startConversation()
val requestHandler =
ToolwindowChatCompletionRequestHandler(
getRequestEventListener(message)
)
expectCodeGPT(StreamHttpExchange { request: RequestEntity ->
assertThat(request.uri.path).isEqualTo("/v1/chat/completions")
assertThat(request.method).isEqualTo("POST")
assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY")
assertThat(request.body)
.extracting(
"model",
"messages")
.containsExactly(
"TEST_MODEL",
listOf(
mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
mapOf("role" to "user", "content" to "TEST_PROMPT")))
listOf(
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))))
})
requestHandler.call(CallParameters(conversation, message))
waitExpecting { "Hello!" == message.response }
}
private fun getRequestEventListener(message: Message): CompletionResponseEventListener {
return object : CompletionResponseEventListener {
override fun handleCompleted(fullMessage: String, callParameters: CallParameters) {
message.response = fullMessage
}
waitExpecting { "Hello!" == message.response }
}
fun testAzureChatCompletionCall() {
useAzureService()
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
val conversationService = ConversationService.getInstance()
val prevMessage = Message("TEST_PREV_PROMPT")
prevMessage.response = "TEST_PREV_RESPONSE"
val conversation = conversationService.startConversation()
conversation.addMessage(prevMessage)
conversationService.saveConversation(conversation)
expectAzure(StreamHttpExchange { request: RequestEntity ->
assertThat(request.uri.path).isEqualTo(
"/openai/deployments/TEST_DEPLOYMENT_ID/chat/completions"
)
assertThat(request.uri.query).isEqualTo("api-version=TEST_API_VERSION")
assertThat(request.headers["Api-key"]!![0]).isEqualTo("TEST_API_KEY")
assertThat(request.headers["X-llm-application-tag"]!![0]).isEqualTo("codegpt")
assertThat(request.body)
.extracting("messages")
.isEqualTo(
listOf(
mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
mapOf("role" to "user", "content" to "TEST_PREV_PROMPT"),
mapOf("role" to "assistant", "content" to "TEST_PREV_RESPONSE"),
mapOf("role" to "user", "content" to "TEST_PROMPT")
)
)
listOf(
jsonMapResponse(
"choices",
jsonArray(jsonMap("delta", jsonMap("role", "assistant")))
),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))
)
})
val message = Message("TEST_PROMPT")
val requestHandler =
ToolwindowChatCompletionRequestHandler(getRequestEventListener(message))
requestHandler.call(ChatCompletionParameters.builder(conversation, message).build())
waitExpecting { "Hello!" == message.response }
}
fun testLlamaChatCompletionCall() {
useLlamaService()
service<ConfigurationSettings>().state.maxTokens = 99
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
val message = Message("TEST_PROMPT")
val conversation = ConversationService.getInstance().startConversation()
conversation.addMessage(Message("Ping", "Pong"))
expectLlama(StreamHttpExchange { request: RequestEntity ->
assertThat(request.uri.path).isEqualTo("/completion")
assertThat(request.body)
.extracting(
"prompt",
"n_predict",
"stream"
)
.containsExactly(
LLAMA.buildPrompt(
"TEST_SYSTEM_PROMPT",
"TEST_PROMPT",
conversation.messages
),
99,
true
)
listOf<String?>(
jsonMapResponse("content", "Hel"),
jsonMapResponse("content", "lo!"),
jsonMapResponse(
e("content", ""),
e("stop", true)
)
)
})
val requestHandler =
ToolwindowChatCompletionRequestHandler(getRequestEventListener(message))
requestHandler.call(ChatCompletionParameters.builder(conversation, message).build())
waitExpecting { "Hello!" == message.response }
}
fun testOllamaChatCompletionCall() {
useOllamaService()
service<ConfigurationSettings>().state.maxTokens = 99
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
val message = Message("TEST_PROMPT")
val conversation = ConversationService.getInstance().startConversation()
expectOllama(NdJsonStreamHttpExchange { request: RequestEntity ->
assertThat(request.uri.path).isEqualTo("/api/chat")
assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY")
assertThat(request.body)
.extracting(
"model",
"messages",
"options.num_predict",
"stream"
)
.containsExactly(
HuggingFaceModel.LLAMA_3_8B_Q6_K.code,
listOf(
mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
mapOf("role" to "user", "content" to "TEST_PROMPT")
),
99,
true
)
listOf(
jsonMapResponse(
e("message", jsonMap(e("content", "Hel"), e("role", "assistant"))),
e("done", false)
),
jsonMapResponse(
e("message", jsonMap(e("content", "lo"), e("role", "assistant"))),
e("done", false)
),
jsonMapResponse(
e("message", jsonMap(e("content", "!"), e("role", "assistant"))),
e("done", false)
),
jsonMapResponse(
e("message", jsonMap(e("content", ""), e("role", "assistant"))),
e("done", true)
)
)
})
val requestHandler =
ToolwindowChatCompletionRequestHandler(getRequestEventListener(message))
requestHandler.call(ChatCompletionParameters.builder(conversation, message).build())
waitExpecting { "Hello!" == message.response }
}
fun testGoogleChatCompletionCall() {
useGoogleService()
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
val message = Message("TEST_PROMPT")
val conversation = ConversationService.getInstance().startConversation()
expectGoogle(StreamHttpExchange { request: RequestEntity ->
assertThat(request.uri.path).isEqualTo("/v1/models/gemini-pro:streamGenerateContent")
assertThat(request.method).isEqualTo("POST")
assertThat(request.uri.query).isEqualTo("key=TEST_API_KEY&alt=sse")
assertThat(request.body)
.extracting("contents")
.isEqualTo(
listOf(
mapOf(
"parts" to listOf(mapOf("text" to "TEST_SYSTEM_PROMPT")),
"role" to "user"
),
mapOf("parts" to listOf(mapOf("text" to "Understood.")), "role" to "model"),
mapOf("parts" to listOf(mapOf("text" to "TEST_PROMPT")), "role" to "user"),
)
)
listOf(
jsonMapResponse(
"candidates",
jsonArray(
jsonMap(
"content",
jsonMap("parts", jsonArray(jsonMap("text", "Hello")))
)
)
),
jsonMapResponse(
"candidates",
jsonArray(jsonMap("content", jsonMap("parts", jsonArray(jsonMap("text", "!")))))
)
)
})
val requestHandler =
ToolwindowChatCompletionRequestHandler(getRequestEventListener(message))
requestHandler.call(ChatCompletionParameters.builder(conversation, message).build())
waitExpecting { "Hello!" == message.response }
}
fun testCodeGPTServiceChatCompletionCall() {
useCodeGPTService()
service<PersonaSettings>().state.selectedPersona.instructions = "TEST_SYSTEM_PROMPT"
val message = Message("TEST_PROMPT")
val conversation = ConversationService.getInstance().startConversation()
expectCodeGPT(StreamHttpExchange { request: RequestEntity ->
assertThat(request.uri.path).isEqualTo("/v1/chat/completions")
assertThat(request.method).isEqualTo("POST")
assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY")
assertThat(request.body)
.extracting(
"model",
"messages"
)
.containsExactly(
"TEST_MODEL",
listOf(
mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
mapOf("role" to "user", "content" to "TEST_PROMPT")
)
)
listOf(
jsonMapResponse(
"choices",
jsonArray(jsonMap("delta", jsonMap("role", "assistant")))
),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))
)
})
val requestHandler =
ToolwindowChatCompletionRequestHandler(getRequestEventListener(message))
requestHandler.call(ChatCompletionParameters.builder(conversation, message).build())
waitExpecting { "Hello!" == message.response }
}
private fun getRequestEventListener(message: Message): CompletionResponseEventListener {
return object : CompletionResponseEventListener {
override fun handleCompleted(
fullMessage: String,
callParameters: ChatCompletionParameters
) {
message.response = fullMessage
}
}
}
}
}

View file

@ -59,7 +59,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
)
})
panel.sendMessage(message)
panel.sendMessage(message, ConversationType.DEFAULT)
waitExpecting {
val messages = conversation.messages
@ -161,7 +161,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
)
})
panel.sendMessage(message)
panel.sendMessage(message, ConversationType.DEFAULT)
waitExpecting {
val messages = conversation.messages
@ -250,7 +250,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
)
})
panel.sendMessage(message)
panel.sendMessage(message, ConversationType.DEFAULT)
waitExpecting {
val messages = conversation.messages