diff --git a/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionRequestFactory.kt index 096e0732..1f82f788 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionRequestFactory.kt @@ -16,6 +16,8 @@ import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest import ee.carlrobert.llm.client.ollama.completion.request.OllamaCompletionRequest import ee.carlrobert.llm.client.ollama.completion.request.OllamaParameters +import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest +import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionStandardMessage import ee.carlrobert.llm.client.openai.completion.request.OpenAITextCompletionRequest import ee.carlrobert.service.GrpcCodeCompletionRequest import okhttp3.MediaType.Companion.toMediaType @@ -52,6 +54,86 @@ object CodeCompletionRequestFactory { .build() } + @JvmStatic + fun buildChatBasedFIMRequest(details: InfillRequest): OpenAIChatCompletionRequest { + val systemMessage = OpenAIChatCompletionStandardMessage( + "system", + "You are a code completion assistant. Complete the code between the given prefix and suffix. " + + "Return only the missing code that should be inserted, without any formatting, explanations, or markdown." + ) + + val userMessage = OpenAIChatCompletionStandardMessage( + "user", + "\n${details.prefix}\n\n\n\n${details.suffix}\n\n\nComplete:" + ) + + return OpenAIChatCompletionRequest.Builder(listOf(systemMessage, userMessage)) + .setModel( + ModelSelectionService.getInstance().getModelForFeature(FeatureType.CODE_COMPLETION) + ) + .setStream(true) + .setMaxTokens(MAX_TOKENS) + .setTemperature(0.0) + .build() + } + + @JvmStatic + fun buildChatBasedFIMHttpRequest( + details: InfillRequest, + url: String, + headers: Map, + body: Map, + credential: String? + ): Request { + val requestBuilder = Request.Builder().url(url) + + for (entry in headers.entries) { + var value = entry.value + if (credential != null && value.contains("\$CUSTOM_SERVICE_API_KEY")) { + value = value.replace("\$CUSTOM_SERVICE_API_KEY", credential) + } + requestBuilder.addHeader(entry.key, value) + } + + // Create chat completion messages using the improved prompt template + val systemMessage = mapOf( + "role" to "system", + "content" to "You are a code completion assistant. Complete the code between the given prefix and suffix. " + + "Return only the missing code that should be inserted, without any formatting, explanations, or markdown." + ) + + val userMessage = mapOf( + "role" to "user", + "content" to "\n${details.prefix}\n\n\n\n${details.suffix}\n\n\nComplete:" + ) + + // Transform the custom body configuration, excluding completion-specific parameters + val transformedBody = body.entries.mapNotNull { (key, value) -> + when (key.lowercase()) { + "messages" -> key to listOf(systemMessage, userMessage) + // Exclude completion-specific parameters that don't apply to chat completions + "prompt", "suffix" -> null + else -> key to transformValue(value, InfillPromptTemplate.CHAT_COMPLETION, details) + } + }.toMap().toMutableMap() + + // Ensure we have messages for chat completion + if (!transformedBody.containsKey("messages")) { + transformedBody["messages"] = listOf(systemMessage, userMessage) + } + + try { + val jsonBody = ObjectMapper() + .writerWithDefaultPrettyPrinter() + .writeValueAsString(transformedBody) + .toByteArray(StandardCharsets.UTF_8) + .toRequestBody("application/json".toMediaType()) + return requestBuilder.post(jsonBody).build() + } catch (e: JsonProcessingException) { + throw RuntimeException(e) + } + } + @JvmStatic fun buildCustomRequest(details: InfillRequest): Request { val activeService = service() @@ -78,6 +160,12 @@ object CodeCompletionRequestFactory { infillTemplate: InfillPromptTemplate, credential: String? ): Request { + // For chat-based FIM, we should not use this method + // The routing logic in CodeCompletionService will handle it + if (infillTemplate == InfillPromptTemplate.CHAT_COMPLETION) { + throw IllegalArgumentException("Chat-based FIM should use buildChatBasedFIMRequest instead") + } + val requestBuilder = Request.Builder().url(url) for (entry in headers.entries) { var value = entry.value diff --git a/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionService.kt b/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionService.kt index e9483831..a225c26a 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionService.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionService.kt @@ -3,11 +3,15 @@ package ee.carlrobert.codegpt.codecompletions import com.intellij.openapi.components.Service import com.intellij.openapi.components.service import com.intellij.openapi.project.Project +import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestFactory.buildChatBasedFIMRequest +import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestFactory.buildChatBasedFIMHttpRequest import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestFactory.buildCustomRequest import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestFactory.buildLlamaRequest import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestFactory.buildOllamaRequest import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestFactory.buildOpenAIRequest import ee.carlrobert.codegpt.completions.CompletionClientProvider +import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey +import ee.carlrobert.codegpt.credentials.CredentialsStore.getCredential import ee.carlrobert.codegpt.settings.service.FeatureType import ee.carlrobert.codegpt.settings.service.ModelSelectionService import ee.carlrobert.codegpt.settings.service.ServiceType @@ -55,23 +59,47 @@ class CodeCompletionService(private val project: Project) { ): EventSource { return when (val selectedService = ModelSelectionService.getInstance().getServiceForFeature(FeatureType.CODE_COMPLETION)) { - OPENAI -> CompletionClientProvider.getOpenAIClient() - .getCompletionAsync(buildOpenAIRequest(infillRequest), eventListener) + OPENAI -> { + val openAISettings = OpenAISettings.getCurrentState() + // Check if user wants to use chat-based FIM (we'll add this setting later) + // For now, default to traditional completion + CompletionClientProvider.getOpenAIClient() + .getCompletionAsync(buildOpenAIRequest(infillRequest), eventListener) + } CUSTOM_OPENAI -> { - val settings = service() - .customServiceStateForFeatureType(FeatureType.CODE_COMPLETION) - .codeCompletionSettings - createFactory( - CompletionClientProvider.getDefaultClientBuilder().build() - ).newEventSource( - buildCustomRequest(infillRequest), - if (settings.parseResponseAsChatCompletions) { + val activeService = service().state.active + val customSettings = activeService.codeCompletionSettings + val isChatBasedFIM = customSettings.infillTemplate == InfillPromptTemplate.CHAT_COMPLETION + + if (isChatBasedFIM) { + // Use chat completion endpoint for chat-based FIM with proper API key substitution + val credential = getCredential(CredentialKey.CustomServiceApiKey(activeService.name.orEmpty())) + createFactory( + CompletionClientProvider.getDefaultClientBuilder().build() + ).newEventSource( + buildChatBasedFIMHttpRequest( + infillRequest, + customSettings.url!!, + customSettings.headers, + customSettings.body, + credential + ), OpenAIChatCompletionEventSourceListener(eventListener) - } else { - OpenAITextCompletionEventSourceListener(eventListener) - } - ) + ) + } else { + // Use traditional completion endpoint + createFactory( + CompletionClientProvider.getDefaultClientBuilder().build() + ).newEventSource( + buildCustomRequest(infillRequest), + if (customSettings.parseResponseAsChatCompletions) { + OpenAIChatCompletionEventSourceListener(eventListener) + } else { + OpenAITextCompletionEventSourceListener(eventListener) + } + ) + } } MISTRAL -> CompletionClientProvider.getMistralClient() diff --git a/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/InfillPromptTemplate.kt b/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/InfillPromptTemplate.kt index 35864da4..1155fca5 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/InfillPromptTemplate.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/InfillPromptTemplate.kt @@ -155,6 +155,13 @@ enum class InfillPromptTemplate(val label: String, val stopTokens: List? "[SUFFIX]${infillDetails.suffix}[PREFIX]${infillDetails.prefix}[MIDDLE]" return createDefaultMultiFilePrompt(infillDetails, infillPrompt) } + }, + CHAT_COMPLETION("Chat-based FIM", listOf("\n\n", "```")) { + override fun buildPrompt(infillDetails: InfillRequest): String { + // This template is used for chat-based FIM completion + // The actual prompt construction is handled in the request factory + return "CHAT_FIM_PLACEHOLDER" + } }; abstract fun buildPrompt(infillDetails: InfillRequest): String diff --git a/src/main/kotlin/ee/carlrobert/codegpt/settings/service/custom/form/CustomServiceCodeCompletionForm.kt b/src/main/kotlin/ee/carlrobert/codegpt/settings/service/custom/form/CustomServiceCodeCompletionForm.kt index f794b44c..6826caee 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/settings/service/custom/form/CustomServiceCodeCompletionForm.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/settings/service/custom/form/CustomServiceCodeCompletionForm.kt @@ -162,17 +162,35 @@ class CustomServiceCodeCompletionForm( } private fun testConnection() { - CompletionRequestService.getInstance().getCustomOpenAICompletionAsync( - CodeCompletionRequestFactory.buildCustomRequest( - InfillRequest.Builder("Hello", "!", 0).build(), - urlField.text, - tabbedPane.headers, - tabbedPane.body, - promptTemplateComboBox.selectedItem as InfillPromptTemplate, - getApiKey.invoke() - ), - TestConnectionEventListener() - ) + val selectedTemplate = promptTemplateComboBox.selectedItem as InfillPromptTemplate + val testRequest = InfillRequest.Builder("Hello", "!", 0).build() + + if (selectedTemplate == InfillPromptTemplate.CHAT_COMPLETION) { + // Use chat completion endpoint for testing + CompletionRequestService.getInstance().getCustomOpenAIChatCompletionAsync( + CodeCompletionRequestFactory.buildChatBasedFIMHttpRequest( + testRequest, + urlField.text, + tabbedPane.headers, + tabbedPane.body, + getApiKey.invoke() + ), + TestConnectionEventListener() + ) + } else { + // Use traditional completion endpoint for testing + CompletionRequestService.getInstance().getCustomOpenAICompletionAsync( + CodeCompletionRequestFactory.buildCustomRequest( + testRequest, + urlField.text, + tabbedPane.headers, + tabbedPane.body, + selectedTemplate, + getApiKey.invoke() + ), + TestConnectionEventListener() + ) + } } internal inner class TestConnectionEventListener : CompletionEventListener { @@ -215,4 +233,4 @@ class CustomServiceCodeCompletionForm( .setDescription("

$description

") .installOn(promptTemplateHelpText) } -} \ No newline at end of file +} diff --git a/src/test/kotlin/ee/carlrobert/codegpt/codecompletions/ChatBasedFIMTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/codecompletions/ChatBasedFIMTest.kt new file mode 100644 index 00000000..5bf594d7 --- /dev/null +++ b/src/test/kotlin/ee/carlrobert/codegpt/codecompletions/ChatBasedFIMTest.kt @@ -0,0 +1,99 @@ +package ee.carlrobert.codegpt.codecompletions + +import ee.carlrobert.codegpt.settings.service.FeatureType +import ee.carlrobert.codegpt.settings.service.ModelSelectionService +import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionStandardMessage +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Assertions.* + +class ChatBasedFIMTest { + + @Test + fun `test chat completion template is available`() { + val templates = InfillPromptTemplate.values() + val chatTemplate = templates.find { it == InfillPromptTemplate.CHAT_COMPLETION } + + assertNotNull(chatTemplate) + assertEquals("Chat-based FIM", chatTemplate?.label) + assertEquals(listOf("\n\n", "```"), chatTemplate?.stopTokens) + } + + @Test + fun `test chat completion template builds placeholder`() { + val template = InfillPromptTemplate.CHAT_COMPLETION + val infillRequest = InfillRequest.Builder("function test() {", "}", 0).build() + + val result = template.buildPrompt(infillRequest) + assertEquals("CHAT_FIM_PLACEHOLDER", result) + } + + @Test + fun `test chat based FIM request creation`() { + val infillRequest = InfillRequest.Builder( + "function calculateSum(a, b) {", + "return result;\n}", + 0 + ).build() + + val chatRequest = CodeCompletionRequestFactory.buildChatBasedFIMRequest(infillRequest) + + assertNotNull(chatRequest) + assertEquals(2, chatRequest.messages.size) + + val systemMessage = chatRequest.messages[0] as OpenAIChatCompletionStandardMessage + assertEquals("system", systemMessage.role) + assertTrue(systemMessage.content.contains("expert coding assistant")) + + val userMessage = chatRequest.messages[1] as OpenAIChatCompletionStandardMessage + assertEquals("user", userMessage.role) + assertTrue(userMessage.content.contains("PREFIX:")) + assertTrue(userMessage.content.contains("SUFFIX:")) + assertTrue(userMessage.content.contains("function calculateSum(a, b) {")) + assertTrue(userMessage.content.contains("return result;")) + + assertTrue(chatRequest.isStream) + assertEquals(128, chatRequest.maxTokens) + assertEquals(0.0, chatRequest.temperature) + } + + @Test + fun `test custom request throws exception for chat completion template`() { + val infillRequest = InfillRequest.Builder("test", "test", 0).build() + + assertThrows(IllegalArgumentException::class.java) { + CodeCompletionRequestFactory.buildCustomRequest( + infillRequest, + "http://test.com", + emptyMap(), + emptyMap(), + InfillPromptTemplate.CHAT_COMPLETION, + null + ) + } + } + + @Test + fun `test chat based FIM HTTP request creation`() { + val infillRequest = InfillRequest.Builder("test prefix", "test suffix", 0).build() + val headers = mapOf("Authorization" to "Bearer \$CUSTOM_SERVICE_API_KEY") + val body = mapOf( + "model" to "gpt-4.1", + "temperature" to 0.2, + "max_tokens" to 24, + "stream" to false + ) + + val httpRequest = CodeCompletionRequestFactory.buildChatBasedFIMHttpRequest( + infillRequest, + "https://api.openai.com/v1/chat/completions", + headers, + body, + "test-api-key" + ) + + assertNotNull(httpRequest) + assertEquals("https://api.openai.com/v1/chat/completions", httpRequest.url.toString()) + assertEquals("Bearer test-api-key", httpRequest.header("Authorization")) + assertEquals("application/json", httpRequest.body?.contentType()?.toString()) + } +}