feat: chat-based FIM code completion for Custom OpenAI providers (#1103)

* Implement FIM code completion with a chat models

* Fix chat-based FIM

* fix: build

---------

Co-authored-by: Carl-Robert Linnupuu <carlrobertoh@gmail.com>
This commit is contained in:
Gustavo Montamat 2025-09-03 06:39:45 -04:00 committed by Carl-Robert Linnupuu
parent 8f020a26eb
commit d9973be131
5 changed files with 266 additions and 26 deletions

View file

@ -16,6 +16,8 @@ import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings
import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest
import ee.carlrobert.llm.client.ollama.completion.request.OllamaCompletionRequest
import ee.carlrobert.llm.client.ollama.completion.request.OllamaParameters
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionStandardMessage
import ee.carlrobert.llm.client.openai.completion.request.OpenAITextCompletionRequest
import ee.carlrobert.service.GrpcCodeCompletionRequest
import okhttp3.MediaType.Companion.toMediaType
@ -52,6 +54,86 @@ object CodeCompletionRequestFactory {
.build()
}
@JvmStatic
fun buildChatBasedFIMRequest(details: InfillRequest): OpenAIChatCompletionRequest {
val systemMessage = OpenAIChatCompletionStandardMessage(
"system",
"You are a code completion assistant. Complete the code between the given prefix and suffix. " +
"Return only the missing code that should be inserted, without any formatting, explanations, or markdown."
)
val userMessage = OpenAIChatCompletionStandardMessage(
"user",
"<PREFIX>\n${details.prefix}\n</PREFIX>\n\n<SUFFIX>\n${details.suffix}\n</SUFFIX>\n\nComplete:"
)
return OpenAIChatCompletionRequest.Builder(listOf(systemMessage, userMessage))
.setModel(
ModelSelectionService.getInstance().getModelForFeature(FeatureType.CODE_COMPLETION)
)
.setStream(true)
.setMaxTokens(MAX_TOKENS)
.setTemperature(0.0)
.build()
}
@JvmStatic
fun buildChatBasedFIMHttpRequest(
details: InfillRequest,
url: String,
headers: Map<String, String>,
body: Map<String, Any>,
credential: String?
): Request {
val requestBuilder = Request.Builder().url(url)
for (entry in headers.entries) {
var value = entry.value
if (credential != null && value.contains("\$CUSTOM_SERVICE_API_KEY")) {
value = value.replace("\$CUSTOM_SERVICE_API_KEY", credential)
}
requestBuilder.addHeader(entry.key, value)
}
// Create chat completion messages using the improved prompt template
val systemMessage = mapOf<String, String>(
"role" to "system",
"content" to "You are a code completion assistant. Complete the code between the given prefix and suffix. " +
"Return only the missing code that should be inserted, without any formatting, explanations, or markdown."
)
val userMessage = mapOf<String, String>(
"role" to "user",
"content" to "<PREFIX>\n${details.prefix}\n</PREFIX>\n\n<SUFFIX>\n${details.suffix}\n</SUFFIX>\n\nComplete:"
)
// Transform the custom body configuration, excluding completion-specific parameters
val transformedBody = body.entries.mapNotNull { (key, value) ->
when (key.lowercase()) {
"messages" -> key to listOf(systemMessage, userMessage)
// Exclude completion-specific parameters that don't apply to chat completions
"prompt", "suffix" -> null
else -> key to transformValue(value, InfillPromptTemplate.CHAT_COMPLETION, details)
}
}.toMap().toMutableMap()
// Ensure we have messages for chat completion
if (!transformedBody.containsKey("messages")) {
transformedBody["messages"] = listOf(systemMessage, userMessage)
}
try {
val jsonBody = ObjectMapper()
.writerWithDefaultPrettyPrinter()
.writeValueAsString(transformedBody)
.toByteArray(StandardCharsets.UTF_8)
.toRequestBody("application/json".toMediaType())
return requestBuilder.post(jsonBody).build()
} catch (e: JsonProcessingException) {
throw RuntimeException(e)
}
}
@JvmStatic
fun buildCustomRequest(details: InfillRequest): Request {
val activeService = service<CustomServicesSettings>()
@ -78,6 +160,12 @@ object CodeCompletionRequestFactory {
infillTemplate: InfillPromptTemplate,
credential: String?
): Request {
// For chat-based FIM, we should not use this method
// The routing logic in CodeCompletionService will handle it
if (infillTemplate == InfillPromptTemplate.CHAT_COMPLETION) {
throw IllegalArgumentException("Chat-based FIM should use buildChatBasedFIMRequest instead")
}
val requestBuilder = Request.Builder().url(url)
for (entry in headers.entries) {
var value = entry.value

View file

@ -3,11 +3,15 @@ package ee.carlrobert.codegpt.codecompletions
import com.intellij.openapi.components.Service
import com.intellij.openapi.components.service
import com.intellij.openapi.project.Project
import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestFactory.buildChatBasedFIMRequest
import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestFactory.buildChatBasedFIMHttpRequest
import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestFactory.buildCustomRequest
import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestFactory.buildLlamaRequest
import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestFactory.buildOllamaRequest
import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestFactory.buildOpenAIRequest
import ee.carlrobert.codegpt.completions.CompletionClientProvider
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey
import ee.carlrobert.codegpt.credentials.CredentialsStore.getCredential
import ee.carlrobert.codegpt.settings.service.FeatureType
import ee.carlrobert.codegpt.settings.service.ModelSelectionService
import ee.carlrobert.codegpt.settings.service.ServiceType
@ -55,23 +59,47 @@ class CodeCompletionService(private val project: Project) {
): EventSource {
return when (val selectedService =
ModelSelectionService.getInstance().getServiceForFeature(FeatureType.CODE_COMPLETION)) {
OPENAI -> CompletionClientProvider.getOpenAIClient()
.getCompletionAsync(buildOpenAIRequest(infillRequest), eventListener)
OPENAI -> {
val openAISettings = OpenAISettings.getCurrentState()
// Check if user wants to use chat-based FIM (we'll add this setting later)
// For now, default to traditional completion
CompletionClientProvider.getOpenAIClient()
.getCompletionAsync(buildOpenAIRequest(infillRequest), eventListener)
}
CUSTOM_OPENAI -> {
val settings = service<CustomServicesSettings>()
.customServiceStateForFeatureType(FeatureType.CODE_COMPLETION)
.codeCompletionSettings
createFactory(
CompletionClientProvider.getDefaultClientBuilder().build()
).newEventSource(
buildCustomRequest(infillRequest),
if (settings.parseResponseAsChatCompletions) {
val activeService = service<CustomServicesSettings>().state.active
val customSettings = activeService.codeCompletionSettings
val isChatBasedFIM = customSettings.infillTemplate == InfillPromptTemplate.CHAT_COMPLETION
if (isChatBasedFIM) {
// Use chat completion endpoint for chat-based FIM with proper API key substitution
val credential = getCredential(CredentialKey.CustomServiceApiKey(activeService.name.orEmpty()))
createFactory(
CompletionClientProvider.getDefaultClientBuilder().build()
).newEventSource(
buildChatBasedFIMHttpRequest(
infillRequest,
customSettings.url!!,
customSettings.headers,
customSettings.body,
credential
),
OpenAIChatCompletionEventSourceListener(eventListener)
} else {
OpenAITextCompletionEventSourceListener(eventListener)
}
)
)
} else {
// Use traditional completion endpoint
createFactory(
CompletionClientProvider.getDefaultClientBuilder().build()
).newEventSource(
buildCustomRequest(infillRequest),
if (customSettings.parseResponseAsChatCompletions) {
OpenAIChatCompletionEventSourceListener(eventListener)
} else {
OpenAITextCompletionEventSourceListener(eventListener)
}
)
}
}
MISTRAL -> CompletionClientProvider.getMistralClient()

View file

@ -155,6 +155,13 @@ enum class InfillPromptTemplate(val label: String, val stopTokens: List<String>?
"[SUFFIX]${infillDetails.suffix}[PREFIX]${infillDetails.prefix}[MIDDLE]"
return createDefaultMultiFilePrompt(infillDetails, infillPrompt)
}
},
CHAT_COMPLETION("Chat-based FIM", listOf("\n\n", "```")) {
override fun buildPrompt(infillDetails: InfillRequest): String {
// This template is used for chat-based FIM completion
// The actual prompt construction is handled in the request factory
return "CHAT_FIM_PLACEHOLDER"
}
};
abstract fun buildPrompt(infillDetails: InfillRequest): String

View file

@ -162,17 +162,35 @@ class CustomServiceCodeCompletionForm(
}
private fun testConnection() {
CompletionRequestService.getInstance().getCustomOpenAICompletionAsync(
CodeCompletionRequestFactory.buildCustomRequest(
InfillRequest.Builder("Hello", "!", 0).build(),
urlField.text,
tabbedPane.headers,
tabbedPane.body,
promptTemplateComboBox.selectedItem as InfillPromptTemplate,
getApiKey.invoke()
),
TestConnectionEventListener()
)
val selectedTemplate = promptTemplateComboBox.selectedItem as InfillPromptTemplate
val testRequest = InfillRequest.Builder("Hello", "!", 0).build()
if (selectedTemplate == InfillPromptTemplate.CHAT_COMPLETION) {
// Use chat completion endpoint for testing
CompletionRequestService.getInstance().getCustomOpenAIChatCompletionAsync(
CodeCompletionRequestFactory.buildChatBasedFIMHttpRequest(
testRequest,
urlField.text,
tabbedPane.headers,
tabbedPane.body,
getApiKey.invoke()
),
TestConnectionEventListener()
)
} else {
// Use traditional completion endpoint for testing
CompletionRequestService.getInstance().getCustomOpenAICompletionAsync(
CodeCompletionRequestFactory.buildCustomRequest(
testRequest,
urlField.text,
tabbedPane.headers,
tabbedPane.body,
selectedTemplate,
getApiKey.invoke()
),
TestConnectionEventListener()
)
}
}
internal inner class TestConnectionEventListener : CompletionEventListener<String?> {
@ -215,4 +233,4 @@ class CustomServiceCodeCompletionForm(
.setDescription("<html><p>$description</p></html>")
.installOn(promptTemplateHelpText)
}
}
}

View file

@ -0,0 +1,99 @@
package ee.carlrobert.codegpt.codecompletions
import ee.carlrobert.codegpt.settings.service.FeatureType
import ee.carlrobert.codegpt.settings.service.ModelSelectionService
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionStandardMessage
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.Assertions.*
class ChatBasedFIMTest {
@Test
fun `test chat completion template is available`() {
val templates = InfillPromptTemplate.values()
val chatTemplate = templates.find { it == InfillPromptTemplate.CHAT_COMPLETION }
assertNotNull(chatTemplate)
assertEquals("Chat-based FIM", chatTemplate?.label)
assertEquals(listOf("\n\n", "```"), chatTemplate?.stopTokens)
}
@Test
fun `test chat completion template builds placeholder`() {
val template = InfillPromptTemplate.CHAT_COMPLETION
val infillRequest = InfillRequest.Builder("function test() {", "}", 0).build()
val result = template.buildPrompt(infillRequest)
assertEquals("CHAT_FIM_PLACEHOLDER", result)
}
@Test
fun `test chat based FIM request creation`() {
val infillRequest = InfillRequest.Builder(
"function calculateSum(a, b) {",
"return result;\n}",
0
).build()
val chatRequest = CodeCompletionRequestFactory.buildChatBasedFIMRequest(infillRequest)
assertNotNull(chatRequest)
assertEquals(2, chatRequest.messages.size)
val systemMessage = chatRequest.messages[0] as OpenAIChatCompletionStandardMessage
assertEquals("system", systemMessage.role)
assertTrue(systemMessage.content.contains("expert coding assistant"))
val userMessage = chatRequest.messages[1] as OpenAIChatCompletionStandardMessage
assertEquals("user", userMessage.role)
assertTrue(userMessage.content.contains("PREFIX:"))
assertTrue(userMessage.content.contains("SUFFIX:"))
assertTrue(userMessage.content.contains("function calculateSum(a, b) {"))
assertTrue(userMessage.content.contains("return result;"))
assertTrue(chatRequest.isStream)
assertEquals(128, chatRequest.maxTokens)
assertEquals(0.0, chatRequest.temperature)
}
@Test
fun `test custom request throws exception for chat completion template`() {
val infillRequest = InfillRequest.Builder("test", "test", 0).build()
assertThrows(IllegalArgumentException::class.java) {
CodeCompletionRequestFactory.buildCustomRequest(
infillRequest,
"http://test.com",
emptyMap(),
emptyMap(),
InfillPromptTemplate.CHAT_COMPLETION,
null
)
}
}
@Test
fun `test chat based FIM HTTP request creation`() {
val infillRequest = InfillRequest.Builder("test prefix", "test suffix", 0).build()
val headers = mapOf("Authorization" to "Bearer \$CUSTOM_SERVICE_API_KEY")
val body = mapOf<String, Any>(
"model" to "gpt-4.1",
"temperature" to 0.2,
"max_tokens" to 24,
"stream" to false
)
val httpRequest = CodeCompletionRequestFactory.buildChatBasedFIMHttpRequest(
infillRequest,
"https://api.openai.com/v1/chat/completions",
headers,
body,
"test-api-key"
)
assertNotNull(httpRequest)
assertEquals("https://api.openai.com/v1/chat/completions", httpRequest.url.toString())
assertEquals("Bearer test-api-key", httpRequest.header("Authorization"))
assertEquals("application/json", httpRequest.body?.contentType()?.toString())
}
}