From 30c255c5b5341f4d37bfa9c76f3898b62ecba272 Mon Sep 17 00:00:00 2001 From: Carl-Robert Linnupuu Date: Tue, 8 Oct 2024 00:33:33 +0300 Subject: [PATCH] feat: support high context limits (CodeGPT) --- gradle/libs.versions.toml | 2 +- .../carlrobert/codegpt/EncodingManager.java | 2 +- .../codegpt/completions/CallParameters.java | 13 +++- .../chat/ChatToolWindowTabPanel.java | 21 +++++-- .../chat/ui/ChatMessageResponseBody.java | 60 +++++++++++++++++-- .../ui/ChatToolWindowScrollablePanel.java | 2 +- .../toolwindow/chat/ui/UserMessagePanel.java | 1 + .../chat/ui/textarea/TotalTokensPanel.java | 14 +++-- .../ui/checkbox/VirtualFileCheckboxTree.java | 14 ++++- .../completions/CompletionRequestFactory.kt | 13 ++++ .../factory/AzureRequestFactory.kt | 5 +- .../factory/ClaudeRequestFactory.kt | 5 +- .../factory/CodeGPTRequestFactory.kt | 10 ++++ .../factory/CustomOpenAIRequestFactory.kt | 6 +- .../factory/GoogleRequestFactory.kt | 7 ++- .../factory/LlamaRequestFactory.kt | 2 +- .../factory/OllamaRequestFactory.kt | 10 +++- .../factory/OpenAIRequestFactory.kt | 21 +++++-- .../carlrobert/codegpt/events/CodeGPTEvent.kt | 12 ++++ .../carlrobert/codegpt/util/file/FileUtil.kt | 4 +- 20 files changed, 190 insertions(+), 34 deletions(-) diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index b0102be8..c0f2c7d5 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -12,7 +12,7 @@ jsoup = "1.17.2" jtokkit = "1.1.0" junit = "5.11.0" kotlin = "2.0.0" -llm-client = "0.8.19" +llm-client = "0.8.21" okio = "3.9.0" tree-sitter = "0.22.6a" diff --git a/src/main/java/ee/carlrobert/codegpt/EncodingManager.java b/src/main/java/ee/carlrobert/codegpt/EncodingManager.java index c3b4d44f..c994f3ea 100644 --- a/src/main/java/ee/carlrobert/codegpt/EncodingManager.java +++ b/src/main/java/ee/carlrobert/codegpt/EncodingManager.java @@ -63,7 +63,7 @@ public final class EncodingManager { public int countTokens(String text) { try { // #444: Cl100kParser.split() throws AssertionError "Input is not UTF-8: " - return encoding.countTokens(text); + return encoding.countTokens(text.replaceAll("<|", "").replaceAll("|>", "")); } catch (Exception | Error ex) { LOG.warn("Could not count tokens for: " + text, ex); return 0; diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CallParameters.java b/src/main/java/ee/carlrobert/codegpt/completions/CallParameters.java index a85f4699..877563ed 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CallParameters.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CallParameters.java @@ -1,7 +1,9 @@ package ee.carlrobert.codegpt.completions; +import ee.carlrobert.codegpt.ReferencedFile; import ee.carlrobert.codegpt.conversations.Conversation; import ee.carlrobert.codegpt.conversations.message.Message; +import java.util.List; import org.jetbrains.annotations.Nullable; public class CallParameters { @@ -11,8 +13,9 @@ public class CallParameters { private final Message message; private final boolean retry; private final String highlightedText; - private @Nullable String imageMediaType; + private String imageMediaType; private byte[] imageData; + private List referencedFiles; public CallParameters(Conversation conversation, Message message) { this(conversation, ConversationType.DEFAULT, message, null, false); @@ -66,4 +69,12 @@ public class CallParameters { public @Nullable String getHighlightedText() { return highlightedText; } + + public @Nullable List getReferencedFiles() { + return referencedFiles; + } + + public void setReferencedFiles(List referencedFiles) { + this.referencedFiles = referencedFiles; + } } diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanel.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanel.java index 964d6c33..e75d6c59 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanel.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanel.java @@ -16,7 +16,6 @@ import ee.carlrobert.codegpt.ReferencedFile; import ee.carlrobert.codegpt.actions.ActionType; import ee.carlrobert.codegpt.completions.CallParameters; import ee.carlrobert.codegpt.completions.CompletionRequestService; -import ee.carlrobert.codegpt.completions.CompletionRequestUtil; import ee.carlrobert.codegpt.completions.ConversationType; import ee.carlrobert.codegpt.completions.ToolwindowChatCompletionRequestHandler; import ee.carlrobert.codegpt.conversations.Conversation; @@ -134,8 +133,6 @@ public class ChatToolWindowTabPanel implements Disposable { .toList(); message.setReferencedFilePaths(referencedFilePaths); message.setUserMessage(message.getPrompt()); - message.setPrompt( - CompletionRequestUtil.getPromptWithContext(referencedFiles, message.getPrompt())); totalTokensPanel.updateReferencedFilesTokens(referencedFiles); @@ -147,6 +144,7 @@ public class ChatToolWindowTabPanel implements Disposable { var attachedFilePath = CodeGPTKeys.IMAGE_ATTACHMENT_FILE_PATH.get(project); var callParameters = getCallParameters(conversationType, message, highlightedText, attachedFilePath); + callParameters.setReferencedFiles(referencedFiles); if (callParameters.getImageData() != null) { message.setImageFilePath(attachedFilePath); chatToolWindowPanel.ifPresent(panel -> panel.clearNotifications(project)); @@ -180,10 +178,23 @@ public class ChatToolWindowTabPanel implements Disposable { return callParameters; } + private boolean hasReferencedFilePaths(Message message) { + return message.getReferencedFilePaths() != null && !message.getReferencedFilePaths().isEmpty(); + } + + private boolean hasReferencedFilePaths(Conversation conversation) { + return conversation.getMessages().stream() + .anyMatch( + it -> it.getReferencedFilePaths() != null && !it.getReferencedFilePaths().isEmpty()); + } + private ResponsePanel createResponsePanel( CallParameters callParameters, ConversationType conversationType) { var message = callParameters.getMessage(); + var fileContextIncluded = + hasReferencedFilePaths(message) || hasReferencedFilePaths(conversation); + return new ResponsePanel() .withReloadAction(() -> reloadMessage(message, conversation, conversationType)) .withDeleteAction(() -> removeMessage(message.getId(), conversation)) @@ -194,7 +205,9 @@ public class ChatToolWindowTabPanel implements Disposable { true, false, message.isWebSearchIncluded(), - message.getDocumentationDetails() != null, this)); + message.getDocumentationDetails() != null, + fileContextIncluded, + this)); } private void reloadMessage( diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/ChatMessageResponseBody.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/ChatMessageResponseBody.java index 604501b8..708cfb03 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/ChatMessageResponseBody.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/ChatMessageResponseBody.java @@ -27,6 +27,7 @@ import ee.carlrobert.codegpt.events.AnalysisCompletedEventDetails; import ee.carlrobert.codegpt.events.AnalysisFailedEventDetails; import ee.carlrobert.codegpt.events.CodeGPTEvent; import ee.carlrobert.codegpt.events.EventDetails; +import ee.carlrobert.codegpt.events.ProcessContextEventDetails; import ee.carlrobert.codegpt.events.WebSearchEventDetails; import ee.carlrobert.codegpt.settings.GeneralSettingsConfigurable; import ee.carlrobert.codegpt.telemetry.TelemetryAction; @@ -44,10 +45,10 @@ import javax.swing.Box; import javax.swing.BoxLayout; import javax.swing.DefaultListModel; import javax.swing.Icon; +import javax.swing.JComponent; import javax.swing.JPanel; import javax.swing.JTextPane; import javax.swing.SwingConstants; -import javax.swing.SwingUtilities; import org.jetbrains.annotations.Nullable; public class ChatMessageResponseBody extends JPanel { @@ -59,7 +60,9 @@ public class ChatMessageResponseBody extends JPanel { private final DefaultListModel webpageListModel = new DefaultListModel<>(); private final WebpageList webpageList = new WebpageList(webpageListModel); private final JPanel webDocProgressContainer = new JPanel(); - private final AsyncProcessIcon spinner = new AsyncProcessIcon("sign_in_spinner"); + private final JPanel progressContainer = new JPanel(); + private final AsyncProcessIcon webDocsSpinner = new AsyncProcessIcon("web_docs_spinner"); + private final AsyncProcessIcon processSpinner = new AsyncProcessIcon("process_spinner"); private final @Nullable String highlightedText; private ResponseEditorPanel currentlyProcessedEditorPanel; private JTextPane currentlyProcessedTextPane; @@ -67,7 +70,7 @@ public class ChatMessageResponseBody extends JPanel { private boolean responseReceived; public ChatMessageResponseBody(Project project, Disposable parentDisposable) { - this(project, null, false, false, false, false, parentDisposable); + this(project, null, false, false, false, false, false, parentDisposable); } public ChatMessageResponseBody( @@ -77,8 +80,8 @@ public class ChatMessageResponseBody extends JPanel { boolean readOnly, boolean webSearchIncluded, boolean webDocIncluded, + boolean fileContextIncluded, Disposable parentDisposable) { - super(new BorderLayout()); this.project = project; this.highlightedText = highlightedText; this.parentDisposable = parentDisposable; @@ -98,6 +101,12 @@ public class ChatMessageResponseBody extends JPanel { add(webDocProgressContainer); } + if (fileContextIncluded) { + progressContainer.setLayout(new BoxLayout(progressContainer, BoxLayout.Y_AXIS)); + progressContainer.setBorder(JBUI.Borders.emptyBottom(8)); + add(progressContainer); + } + if (withGhostText) { prepareProcessingText(!readOnly); currentlyProcessedTextPane.setText( @@ -209,6 +218,7 @@ public class ChatMessageResponseBody extends JPanel { case ANALYZE_WEB_DOC_STARTED -> showWebDocsProgress(); case ANALYZE_WEB_DOC_COMPLETED -> completeWebDocsProgress(event.getDetails()); case ANALYZE_WEB_DOC_FAILED -> failWebDocsProgress(event.getDetails()); + case PROCESS_CONTEXT -> showProcessContextEvent(event.getDetails()); default -> { } } @@ -305,7 +315,7 @@ public class ChatMessageResponseBody extends JPanel { private void showWebDocsProgress() { var wrapper = new JPanel(new FlowLayout(FlowLayout.LEADING, 0, 0)); - wrapper.add(spinner); + wrapper.add(webDocsSpinner); wrapper.add(Box.createHorizontalStrut(4)); wrapper.add(new JBLabel( CodeGPTBundle.get("chatMessageResponseBody.webDocs.startProgress.label")).withFont( @@ -325,10 +335,50 @@ public class ChatMessageResponseBody extends JPanel { } } + private void showProcessContextEvent(EventDetails eventDetails) { + if (eventDetails instanceof ProcessContextEventDetails details) { + switch (details.getStatus()) { + case "STARTED": { + updateProgressContainer(details.getDescription(), null); + break; + } + case "FAILED": { + updateProgressContainer(details.getDescription(), General.Error); + break; + } + case "COMPLETED": { + updateProgressContainer(details.getDescription(), Icons.GreenCheckmark); + break; + } + default: + break; + } + } + } + private void updateWebDocsProgressLabel(String text, Icon icon) { updateWebDocsProgress(new JBLabel(text, icon, SwingConstants.LEADING).withFont(JBFont.small())); } + private void updateProgressContainer(String text, @Nullable Icon icon) { + ApplicationManager.getApplication().invokeLater(() -> { + progressContainer.removeAll(); + JComponent wrapper; + if (icon != null) { + wrapper = new JBLabel(text, icon, SwingConstants.LEADING); + ((JBLabel) wrapper).setHorizontalTextPosition(SwingConstants.LEADING); + } else { + wrapper = new JPanel(new FlowLayout(FlowLayout.LEADING, 0, 0)); + wrapper.add(new JBLabel(text)); + wrapper.add(Box.createHorizontalStrut(4)); + wrapper.add(processSpinner); + } + progressContainer.add(JBUI.Panels.simplePanel(wrapper)); + progressContainer.revalidate(); + progressContainer.repaint(); + }); + } + private void updateWebDocsProgress(Component content) { webDocProgressContainer.removeAll(); webDocProgressContainer.add(JBUI.Panels.simplePanel(content)); diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/ChatToolWindowScrollablePanel.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/ChatToolWindowScrollablePanel.java index 29733898..2abc21d4 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/ChatToolWindowScrollablePanel.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/ChatToolWindowScrollablePanel.java @@ -43,7 +43,7 @@ public class ChatToolWindowScrollablePanel extends ScrollablePanel { It looks like you haven't configured your API key yet. Visit CodeGPT settings to do so.

- Don't have an account? Sign up for free access to all open-source models. + Don't have an account? Sign up for free access to all models.

""", false, diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/UserMessagePanel.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/UserMessagePanel.java index efb75823..ec9ccb52 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/UserMessagePanel.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/UserMessagePanel.java @@ -102,6 +102,7 @@ public class UserMessagePanel extends JPanel { true, false, false, + false, parentDisposable) .withResponse(prompt); } diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/textarea/TotalTokensPanel.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/textarea/TotalTokensPanel.java index 827c70e1..82316d79 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/textarea/TotalTokensPanel.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/textarea/TotalTokensPanel.java @@ -152,7 +152,7 @@ public class TotalTokensPanel extends JPanel { "Referenced Files Tokens", totalTokensDetails.getReferencedFilesTokens())) .entrySet().stream() .map(entry -> format( - "

%s: %d

", + "

%s: %d

", entry.getKey(), entry.getValue())) .collect(Collectors.joining()); @@ -165,14 +165,16 @@ public class TotalTokensPanel extends JPanel { private String getIconToolTipText(String html) { if (!GeneralSettings.isSelected(ServiceType.OPENAI)) { return """ - + + + %s +

- ⓘ Keep in mind that the output values might vary across different - large language models due to variations in their encoding methods. + Note: Output values might vary across different large language models + due to variations in their encoding methods.

- %s + """.formatted(html); } return ""; diff --git a/src/main/java/ee/carlrobert/codegpt/ui/checkbox/VirtualFileCheckboxTree.java b/src/main/java/ee/carlrobert/codegpt/ui/checkbox/VirtualFileCheckboxTree.java index 1f69a538..72d91702 100644 --- a/src/main/java/ee/carlrobert/codegpt/ui/checkbox/VirtualFileCheckboxTree.java +++ b/src/main/java/ee/carlrobert/codegpt/ui/checkbox/VirtualFileCheckboxTree.java @@ -1,10 +1,12 @@ package ee.carlrobert.codegpt.ui.checkbox; import com.intellij.icons.AllIcons; +import com.intellij.notification.NotificationType; import com.intellij.openapi.vfs.VirtualFile; import com.intellij.ui.CheckedTreeNode; import com.intellij.util.PlatformIcons; import ee.carlrobert.codegpt.ReferencedFile; +import ee.carlrobert.codegpt.ui.OverlayUtil; import java.io.File; import java.util.Arrays; import java.util.List; @@ -19,12 +21,20 @@ public class VirtualFileCheckboxTree extends FileCheckboxTree { public List getReferencedFiles() { var checkedNodes = getCheckedNodes(VirtualFile.class, Objects::nonNull); - if (checkedNodes.length > 1000) { + if (checkedNodes.length > 1024) { + OverlayUtil.showNotification("Too many files selected.", NotificationType.ERROR); throw new RuntimeException("Too many files selected"); } return Arrays.stream(checkedNodes) - .map(item -> new ReferencedFile(new File(item.getPath()))) + .map(item -> { + var file = new File(item.getPath()); + if (file.isFile()) { + return new ReferencedFile(file); + } + return null; + }) + .filter(Objects::nonNull) .toList(); } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestFactory.kt index 465b9a24..50fde985 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestFactory.kt @@ -49,4 +49,17 @@ abstract class BaseRequestFactory : CompletionRequestFactory { maxTokens: Int = 4096, stream: Boolean = false ): CompletionRequest + + protected fun getPromptWithFilesContext(callParameters: CallParameters): String { + return callParameters.referencedFiles?.let { + if (it.isEmpty()) { + callParameters.message.prompt + } else { + CompletionRequestUtil.getPromptWithContext( + it, + callParameters.message.prompt + ) + } + } ?: return callParameters.message.prompt + } } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/AzureRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/AzureRequestFactory.kt index 532d43a4..2f8c723f 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/AzureRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/AzureRequestFactory.kt @@ -12,8 +12,11 @@ class AzureRequestFactory : BaseRequestFactory() { override fun createChatRequest(params: ChatCompletionRequestParameters): OpenAIChatCompletionRequest { val configuration = service().state + val (callParameters) = params val requestBuilder: OpenAIChatCompletionRequest.Builder = - OpenAIChatCompletionRequest.Builder(buildOpenAIMessages(null, params.callParameters)) + OpenAIChatCompletionRequest.Builder( + buildOpenAIMessages(null, callParameters, callParameters.referencedFiles) + ) .setMaxTokens(configuration.maxTokens) .setStream(true) .setTemperature(configuration.temperature.toDouble()) diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/ClaudeRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/ClaudeRequestFactory.kt index ca3a4307..78f5509b 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/ClaudeRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/ClaudeRequestFactory.kt @@ -48,7 +48,10 @@ class ClaudeRequestFactory : BaseRequestFactory() { else -> { messages.add( - ClaudeCompletionStandardMessage("user", callParameters.message.prompt) + ClaudeCompletionStandardMessage( + "user", + getPromptWithFilesContext(callParameters) + ) ) } } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CodeGPTRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CodeGPTRequestFactory.kt index bb34586d..353c7573 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CodeGPTRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CodeGPTRequestFactory.kt @@ -7,6 +7,8 @@ import ee.carlrobert.codegpt.completions.factory.OpenAIRequestFactory.Companion. import ee.carlrobert.codegpt.completions.factory.OpenAIRequestFactory.Companion.buildOpenAIMessages import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTServiceSettings +import ee.carlrobert.llm.client.openai.completion.request.Context +import ee.carlrobert.llm.client.openai.completion.request.FileContext import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest import ee.carlrobert.llm.client.openai.completion.request.RequestDocumentationDetails @@ -19,6 +21,7 @@ class CodeGPTRequestFactory : BaseRequestFactory() { val requestBuilder: OpenAIChatCompletionRequest.Builder = OpenAIChatCompletionRequest.Builder(buildOpenAIMessages(model, callParameters)) .setModel(model) + if ("o1-mini" == model || "o1-preview" == model) { requestBuilder .setMaxCompletionTokens(configuration.maxTokens) @@ -42,6 +45,13 @@ class CodeGPTRequestFactory : BaseRequestFactory() { requestDocumentationDetails.url = documentationDetails.url requestBuilder.setDocumentationDetails(requestDocumentationDetails) } + callParameters.referencedFiles?.let { + val fileContexts = it.map { file -> + FileContext(file.fileName, file.fileContent) + } + requestBuilder.setContext(Context(fileContexts)) + } + return requestBuilder.build() } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CustomOpenAIRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CustomOpenAIRequestFactory.kt index 1532d449..8197d5a2 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CustomOpenAIRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CustomOpenAIRequestFactory.kt @@ -25,7 +25,11 @@ class CustomOpenAIRequestFactory : BaseRequestFactory() { service() .state .chatCompletionSettings, - OpenAIRequestFactory.buildOpenAIMessages(null, callParameters), + OpenAIRequestFactory.buildOpenAIMessages( + null, + callParameters, + callParameters.referencedFiles + ), true, getCredential(CredentialKey.CUSTOM_SERVICE_API_KEY) ) diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/GoogleRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/GoogleRequestFactory.kt index 5d8245eb..d95f2ade 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/GoogleRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/GoogleRequestFactory.kt @@ -159,7 +159,12 @@ class GoogleRequestFactory : BaseRequestFactory() { ) ) } else { - messages.add(GoogleCompletionContent("user", listOf(message.prompt))) + messages.add( + GoogleCompletionContent( + "user", + listOf(getPromptWithFilesContext(callParameters)) + ) + ) } return messages diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/LlamaRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/LlamaRequestFactory.kt index 8cc95637..085f6f14 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/LlamaRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/LlamaRequestFactory.kt @@ -24,7 +24,7 @@ class LlamaRequestFactory : BaseRequestFactory() { getSystemPrompt() val prompt = promptTemplate.buildPrompt( systemPrompt, - callParameters.message.prompt, + getPromptWithFilesContext(callParameters), callParameters.conversation.messages ) diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OllamaRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OllamaRequestFactory.kt index 0dd2aa01..3ecebe3a 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OllamaRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OllamaRequestFactory.kt @@ -87,7 +87,15 @@ class OllamaRequestFactory : BaseRequestFactory() { } catch (e: IOException) { throw RuntimeException(e) } - } ?: messages.add(OllamaChatCompletionMessage("user", prevMessage.prompt, null)) + } ?: run { + messages.add( + OllamaChatCompletionMessage( + "user", + getPromptWithFilesContext(callParameters), + null + ) + ) + } messages.add(OllamaChatCompletionMessage("assistant", prevMessage.response, null)) } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OpenAIRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OpenAIRequestFactory.kt index 56598fad..d83d71a2 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OpenAIRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OpenAIRequestFactory.kt @@ -2,6 +2,7 @@ package ee.carlrobert.codegpt.completions.factory import com.intellij.openapi.components.service import ee.carlrobert.codegpt.EncodingManager +import ee.carlrobert.codegpt.ReferencedFile import ee.carlrobert.codegpt.completions.* import ee.carlrobert.codegpt.completions.CompletionRequestUtil.EDIT_CODE_SYSTEM_PROMPT import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT @@ -25,7 +26,9 @@ class OpenAIRequestFactory : CompletionRequestFactory { val model = service().state.model val configuration = service().state val requestBuilder: OpenAIChatCompletionRequest.Builder = - OpenAIChatCompletionRequest.Builder(buildOpenAIMessages(model, callParameters)) + OpenAIChatCompletionRequest.Builder( + buildOpenAIMessages(model, callParameters, callParameters.referencedFiles) + ) .setModel(model) if ("o1-mini" == model || "o1-preview" == model) { requestBuilder @@ -41,6 +44,7 @@ class OpenAIRequestFactory : CompletionRequestFactory { .setMaxTokens(configuration.maxTokens) .setTemperature(configuration.temperature.toDouble()) } + return requestBuilder.build() } @@ -98,9 +102,10 @@ class OpenAIRequestFactory : CompletionRequestFactory { fun buildOpenAIMessages( model: String?, - callParameters: CallParameters + callParameters: CallParameters, + referencedFiles: List? = mutableListOf() ): List { - val messages = buildOpenAIChatMessages(model, callParameters) + val messages = buildOpenAIChatMessages(model, callParameters, referencedFiles) if (model == null) { return messages @@ -134,7 +139,8 @@ class OpenAIRequestFactory : CompletionRequestFactory { private fun buildOpenAIChatMessages( model: String?, - callParameters: CallParameters + callParameters: CallParameters, + referencedFiles: List? = mutableListOf() ): MutableList { val message = callParameters.message val messages = mutableListOf() @@ -212,7 +218,12 @@ class OpenAIRequestFactory : CompletionRequestFactory { ) ) } else { - messages.add(OpenAIChatCompletionStandardMessage("user", message.prompt)) + val prompt = if (referencedFiles.isNullOrEmpty()) { + message.prompt + } else { + CompletionRequestUtil.getPromptWithContext(referencedFiles, message.prompt) + } + messages.add(OpenAIChatCompletionStandardMessage("user", prompt)) } return messages } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/events/CodeGPTEvent.kt b/src/main/kotlin/ee/carlrobert/codegpt/events/CodeGPTEvent.kt index 6fdc97cd..c4b19706 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/events/CodeGPTEvent.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/events/CodeGPTEvent.kt @@ -25,6 +25,7 @@ data class Event @JsonCreator constructor( ANALYZE_WEB_DOC_STARTED, ANALYZE_WEB_DOC_COMPLETED, ANALYZE_WEB_DOC_FAILED, + PROCESS_CONTEXT, WEB_SEARCH_ITEM } } @@ -52,6 +53,13 @@ data class AnalysisFailedEventDetails( val error: String ) : EventDetails +data class ProcessContextEventDetails( + val id: UUID, + val name: String, + val description: String, + val status: String +) : EventDetails + data class DefaultEventDetails( val id: UUID, val name: String, @@ -78,6 +86,10 @@ class EventDeserializer : StdDeserializer(Event::class.java) { objectMapper.treeToValue(detailsNode, AnalysisFailedEventDetails::class.java) } + EventType.PROCESS_CONTEXT -> { + objectMapper.treeToValue(detailsNode, ProcessContextEventDetails::class.java) + } + else -> { objectMapper.treeToValue(detailsNode, DefaultEventDetails::class.java) } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/util/file/FileUtil.kt b/src/main/kotlin/ee/carlrobert/codegpt/util/file/FileUtil.kt index ecdb849e..db147036 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/util/file/FileUtil.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/util/file/FileUtil.kt @@ -160,9 +160,9 @@ object FileUtil { } @JvmStatic - fun getResourceContent(name: String?): String { + fun getResourceContent(filePath: String?): String { try { - Objects.requireNonNull(name?.let { FileUtil::class.java.getResourceAsStream(it) }) + Objects.requireNonNull(filePath?.let { FileUtil::class.java.getResourceAsStream(it) }) .use { stream -> return String(stream.readAllBytes(), StandardCharsets.UTF_8) }