From 020e1aff7f294394be7f59adbedb5cd5f7c86026 Mon Sep 17 00:00:00 2001 From: Carl-Robert Linnupuu Date: Wed, 9 Oct 2024 11:56:22 +0300 Subject: [PATCH] chore: include metadata in codegpt requests --- gradle/libs.versions.toml | 2 +- .../codegpt/completions/CallParameters.java | 15 +++- .../completions/CompletionRequestService.java | 24 ++++--- .../codegpt/toolwindow/chat/ChatSession.java | 34 +++++++++ .../chat/ChatToolWindowTabPanel.java | 20 +++++- .../factory/CodeGPTRequestFactory.kt | 72 +++++++++++++------ .../CompletionRequestProviderTest.kt | 33 ++------- 7 files changed, 136 insertions(+), 64 deletions(-) create mode 100644 src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatSession.java diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index c0f2c7d5..c16dd51d 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -12,7 +12,7 @@ jsoup = "1.17.2" jtokkit = "1.1.0" junit = "5.11.0" kotlin = "2.0.0" -llm-client = "0.8.21" +llm-client = "0.8.22" okio = "3.9.0" tree-sitter = "0.22.6a" diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CallParameters.java b/src/main/java/ee/carlrobert/codegpt/completions/CallParameters.java index 877563ed..791c7404 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CallParameters.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CallParameters.java @@ -4,10 +4,12 @@ import ee.carlrobert.codegpt.ReferencedFile; import ee.carlrobert.codegpt.conversations.Conversation; import ee.carlrobert.codegpt.conversations.message.Message; import java.util.List; +import java.util.UUID; import org.jetbrains.annotations.Nullable; public class CallParameters { + private final UUID sessionId; private final Conversation conversation; private final ConversationType conversationType; private final Message message; @@ -18,15 +20,22 @@ public class CallParameters { private List referencedFiles; public CallParameters(Conversation conversation, Message message) { - this(conversation, ConversationType.DEFAULT, message, null, false); + this(null, conversation, message); } + public CallParameters(UUID sessionId, Conversation conversation, Message message) { + this(sessionId, conversation, ConversationType.DEFAULT, message, null, false); + } + + // TODO: Builder public CallParameters( + UUID sessionId, Conversation conversation, ConversationType conversationType, Message message, @Nullable String highlightedText, boolean retry) { + this.sessionId = sessionId; this.conversation = conversation; this.conversationType = conversationType; this.message = message; @@ -34,6 +43,10 @@ public class CallParameters { this.retry = retry; } + public UUID getSessionId() { + return sessionId; + } + public Conversation getConversation() { return conversation; } diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java index 56c06c5b..63bfe386 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java @@ -15,6 +15,7 @@ import ee.carlrobert.codegpt.settings.service.azure.AzureSettings; import ee.carlrobert.codegpt.settings.service.google.GoogleSettings; import ee.carlrobert.llm.client.DeserializationUtil; import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionRequest; +import ee.carlrobert.llm.client.codegpt.request.chat.ChatCompletionRequest; import ee.carlrobert.llm.client.google.completion.GoogleCompletionRequest; import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest; import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionRequest; @@ -98,13 +99,6 @@ public final class CompletionRequestService { CompletionEventListener eventListener) { if (request instanceof OpenAIChatCompletionRequest completionRequest) { return switch (GeneralSettings.getSelectedService()) { - case CODEGPT -> { - if (List.of("o1-mini", "o1-preview").contains(completionRequest.getModel())) { - yield getO1ChatCompletionAsync(completionRequest, eventListener); - } - yield CompletionClientProvider.getCodeGPTClient() - .getChatCompletionAsync(completionRequest, eventListener); - } case OPENAI -> { if (List.of("o1-mini", "o1-preview").contains(completionRequest.getModel())) { yield getO1ChatCompletionAsync(completionRequest, eventListener); @@ -117,6 +111,13 @@ public final class CompletionRequestService { default -> throw new RuntimeException("Unknown service selected"); }; } + if (request instanceof ChatCompletionRequest completionRequest) { + if (List.of("o1-mini", "o1-preview").contains(completionRequest.getModel())) { + return getO1ChatCompletionAsync(completionRequest, eventListener); + } + return CompletionClientProvider.getCodeGPTClient() + .getChatCompletionAsync(completionRequest, eventListener); + } if (request instanceof CustomOpenAIRequest completionRequest) { return getCustomOpenAIChatCompletionAsync(completionRequest.getRequest(), eventListener); } @@ -148,7 +149,7 @@ public final class CompletionRequestService { } private EventSource getO1ChatCompletionAsync( - OpenAIChatCompletionRequest request, + CompletionRequest request, CompletionEventListener eventListener) { ProgressManager.getInstance() .run(new Task.Backgroundable(null, "CodeGPT: Processing o1 request") { @@ -176,8 +177,6 @@ public final class CompletionRequestService { public String getChatCompletion(CompletionRequest request) { if (request instanceof OpenAIChatCompletionRequest completionRequest) { var response = switch (GeneralSettings.getSelectedService()) { - case CODEGPT -> CompletionClientProvider.getCodeGPTClient() - .getChatCompletion(completionRequest); case OPENAI -> CompletionClientProvider.getOpenAIClient() .getChatCompletion(completionRequest); case AZURE -> CompletionClientProvider.getAzureClient() @@ -186,6 +185,11 @@ public final class CompletionRequestService { }; return tryExtractContent(response).orElseThrow(); } + if (request instanceof ChatCompletionRequest completionRequest) { + var response = + CompletionClientProvider.getCodeGPTClient().getChatCompletion(completionRequest); + return tryExtractContent(response).orElseThrow(); + } if (request instanceof CustomOpenAIRequest completionRequest) { var httpClient = CompletionClientProvider.getDefaultClientBuilder().build(); try (var response = httpClient.newCall(completionRequest.getRequest()).execute()) { diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatSession.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatSession.java new file mode 100644 index 00000000..b9dd8147 --- /dev/null +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatSession.java @@ -0,0 +1,34 @@ +package ee.carlrobert.codegpt.toolwindow.chat; + +import ee.carlrobert.codegpt.ReferencedFile; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.UUID; + +public class ChatSession { + + private final UUID id; + private final Set referencedFiles; + + public ChatSession() { + this.id = UUID.randomUUID(); + this.referencedFiles = new HashSet<>(); + } + + public UUID getId() { + return id; + } + + public List getReferencedFiles() { + return new ArrayList<>(referencedFiles); + } + + public void addReferencedFiles(List files) { + if (files == null) { + throw new IllegalArgumentException("Referenced files cannot be null"); + } + referencedFiles.addAll(files); + } +} diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanel.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanel.java index e75d6c59..d1817ab9 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanel.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanel.java @@ -51,6 +51,8 @@ public class ChatToolWindowTabPanel implements Disposable { private static final Logger LOG = Logger.getInstance(ChatToolWindowTabPanel.class); + private final ChatSession chatSession; + private final Project project; private final JPanel rootPanel; private final Conversation conversation; @@ -64,6 +66,7 @@ public class ChatToolWindowTabPanel implements Disposable { public ChatToolWindowTabPanel(@NotNull Project project, @NotNull Conversation conversation) { this.project = project; this.conversation = conversation; + this.chatSession = new ChatSession(); conversationService = ConversationService.getInstance(); toolWindowScrollablePanel = new ChatToolWindowScrollablePanel(); totalTokensPanel = new TotalTokensPanel( @@ -165,8 +168,13 @@ public class ChatToolWindowTabPanel implements Disposable { Message message, @Nullable String highlightedText, @Nullable String attachedFilePath) { - var callParameters = new CallParameters(conversation, conversationType, message, - highlightedText, false); + var callParameters = new CallParameters( + chatSession.getId(), + conversation, + conversationType, + message, + highlightedText, + false); if (attachedFilePath != null && !attachedFilePath.isEmpty()) { try { callParameters.setImageData(Files.readAllBytes(Path.of(attachedFilePath))); @@ -227,7 +235,13 @@ public class ChatToolWindowTabPanel implements Disposable { if (responsePanel != null) { message.setResponse(""); conversationService.saveMessage(conversation, message); - call(new CallParameters(conversation, conversationType, message, null, true), + call(new CallParameters( + chatSession.getId(), + conversation, + conversationType, + message, + null, + true), responsePanel); } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CodeGPTRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CodeGPTRequestFactory.kt index 353c7573..44218d09 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CodeGPTRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CodeGPTRequestFactory.kt @@ -1,30 +1,35 @@ package ee.carlrobert.codegpt.completions.factory +import com.intellij.openapi.application.ApplicationInfo import com.intellij.openapi.components.service +import ee.carlrobert.codegpt.CodeGPTPlugin import ee.carlrobert.codegpt.completions.BaseRequestFactory import ee.carlrobert.codegpt.completions.ChatCompletionRequestParameters -import ee.carlrobert.codegpt.completions.factory.OpenAIRequestFactory.Companion.buildBasicO1Request import ee.carlrobert.codegpt.completions.factory.OpenAIRequestFactory.Companion.buildOpenAIMessages import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTServiceSettings -import ee.carlrobert.llm.client.openai.completion.request.Context -import ee.carlrobert.llm.client.openai.completion.request.FileContext -import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest -import ee.carlrobert.llm.client.openai.completion.request.RequestDocumentationDetails +import ee.carlrobert.llm.client.codegpt.request.chat.* +import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionStandardMessage class CodeGPTRequestFactory : BaseRequestFactory() { - override fun createChatRequest(params: ChatCompletionRequestParameters): OpenAIChatCompletionRequest { + override fun createChatRequest(params: ChatCompletionRequestParameters): ChatCompletionRequest { val (callParameters) = params val model = service().state.chatCompletionSettings.model val configuration = service().state - val requestBuilder: OpenAIChatCompletionRequest.Builder = - OpenAIChatCompletionRequest.Builder(buildOpenAIMessages(model, callParameters)) + val requestBuilder: ChatCompletionRequest.Builder = + ChatCompletionRequest.Builder(buildOpenAIMessages(model, callParameters)) .setModel(model) + .setSessionId(callParameters.sessionId) + .setMetadata( + Metadata( + CodeGPTPlugin.getVersion(), + service().build.asString() + ) + ) if ("o1-mini" == model || "o1-preview" == model) { requestBuilder - .setMaxCompletionTokens(configuration.maxTokens) .setStream(false) .setMaxTokens(null) .setTemperature(null) @@ -40,16 +45,15 @@ class CodeGPTRequestFactory : BaseRequestFactory() { } val documentationDetails = callParameters.message.documentationDetails if (documentationDetails != null) { - val requestDocumentationDetails = RequestDocumentationDetails() - requestDocumentationDetails.name = documentationDetails.name - requestDocumentationDetails.url = documentationDetails.url - requestBuilder.setDocumentationDetails(requestDocumentationDetails) + requestBuilder.setDocumentationDetails( + DocumentationDetails(documentationDetails.name, documentationDetails.url) + ) } callParameters.referencedFiles?.let { val fileContexts = it.map { file -> - FileContext(file.fileName, file.fileContent) + ContextFile(file.fileName, file.fileContent) } - requestBuilder.setContext(Context(fileContexts)) + requestBuilder.setContext(AdditionalRequestContext(fileContexts)) } return requestBuilder.build() @@ -60,16 +64,42 @@ class CodeGPTRequestFactory : BaseRequestFactory() { userPrompt: String, maxTokens: Int, stream: Boolean - ): OpenAIChatCompletionRequest { + ): ChatCompletionRequest { val model = service().state.chatCompletionSettings.model if (model == "o1-mini" || model == "o1-preview") { return buildBasicO1Request(model, userPrompt, systemPrompt, maxTokens) } - return OpenAIRequestFactory.createBasicCompletionRequest( - systemPrompt, - userPrompt, - model, - stream + + return ChatCompletionRequest.Builder( + listOf( + OpenAIChatCompletionStandardMessage("system", systemPrompt), + OpenAIChatCompletionStandardMessage("user", userPrompt) + ) ) + .setModel(model) + .setStream(stream) + .build() + } + + fun buildBasicO1Request( + model: String, + prompt: String, + systemPrompt: String = "", + maxCompletionTokens: Int = 4096 + ): ChatCompletionRequest { + val messages = if (systemPrompt.isEmpty()) { + listOf(OpenAIChatCompletionStandardMessage("user", prompt)) + } else { + listOf( + OpenAIChatCompletionStandardMessage("user", systemPrompt), + OpenAIChatCompletionStandardMessage("user", prompt) + ) + } + return ChatCompletionRequest.Builder(messages) + .setModel(model) + .setMaxTokens(maxCompletionTokens) + .setStream(false) + .setTemperature(null) + .build() } } diff --git a/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt index bb2b6690..bf0cab60 100644 --- a/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt +++ b/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt @@ -25,13 +25,7 @@ class CompletionRequestProviderTest : IntegrationTest() { val request = OpenAIRequestFactory().createChatRequest( ChatCompletionRequestParameters( - CallParameters( - conversation, - ConversationType.DEFAULT, - Message("TEST_CHAT_COMPLETION_PROMPT"), - null, - false - ) + CallParameters(conversation, Message("TEST_CHAT_COMPLETION_PROMPT")) ) ) @@ -58,13 +52,7 @@ class CompletionRequestProviderTest : IntegrationTest() { val request = OpenAIRequestFactory().createChatRequest( ChatCompletionRequestParameters( - CallParameters( - conversation, - ConversationType.DEFAULT, - Message("TEST_CHAT_COMPLETION_PROMPT"), - null, - false - ) + CallParameters(conversation, Message("TEST_CHAT_COMPLETION_PROMPT")) ) ) @@ -92,6 +80,7 @@ class CompletionRequestProviderTest : IntegrationTest() { val request = OpenAIRequestFactory().createChatRequest( ChatCompletionRequestParameters( CallParameters( + null, conversation, ConversationType.DEFAULT, secondMessage, @@ -125,13 +114,7 @@ class CompletionRequestProviderTest : IntegrationTest() { val request = OpenAIRequestFactory().createChatRequest( ChatCompletionRequestParameters( - CallParameters( - conversation, - ConversationType.DEFAULT, - Message("TEST_CHAT_COMPLETION_PROMPT"), - null, - false - ) + CallParameters(conversation, Message("TEST_CHAT_COMPLETION_PROMPT")) ) ) @@ -155,13 +138,7 @@ class CompletionRequestProviderTest : IntegrationTest() { assertThrows(TotalUsageExceededException::class.java) { OpenAIRequestFactory().createChatRequest( ChatCompletionRequestParameters( - CallParameters( - conversation, - ConversationType.DEFAULT, - createDummyMessage(100), - null, - false - ) + CallParameters(conversation, createDummyMessage(100)) ) ) }