mirror of
https://github.com/carlrobertoh/ProxyAI.git
synced 2026-05-19 07:54:46 +00:00
feat: chat-based FIM code completion for Custom OpenAI providers (#1103)
* Implement FIM code completion with a chat models * Fix chat-based FIM * fix: build --------- Co-authored-by: Carl-Robert Linnupuu <carlrobertoh@gmail.com>
This commit is contained in:
parent
8f020a26eb
commit
d9973be131
5 changed files with 266 additions and 26 deletions
|
|
@ -16,6 +16,8 @@ import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings
|
|||
import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest
|
||||
import ee.carlrobert.llm.client.ollama.completion.request.OllamaCompletionRequest
|
||||
import ee.carlrobert.llm.client.ollama.completion.request.OllamaParameters
|
||||
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest
|
||||
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionStandardMessage
|
||||
import ee.carlrobert.llm.client.openai.completion.request.OpenAITextCompletionRequest
|
||||
import ee.carlrobert.service.GrpcCodeCompletionRequest
|
||||
import okhttp3.MediaType.Companion.toMediaType
|
||||
|
|
@ -52,6 +54,86 @@ object CodeCompletionRequestFactory {
|
|||
.build()
|
||||
}
|
||||
|
||||
@JvmStatic
|
||||
fun buildChatBasedFIMRequest(details: InfillRequest): OpenAIChatCompletionRequest {
|
||||
val systemMessage = OpenAIChatCompletionStandardMessage(
|
||||
"system",
|
||||
"You are a code completion assistant. Complete the code between the given prefix and suffix. " +
|
||||
"Return only the missing code that should be inserted, without any formatting, explanations, or markdown."
|
||||
)
|
||||
|
||||
val userMessage = OpenAIChatCompletionStandardMessage(
|
||||
"user",
|
||||
"<PREFIX>\n${details.prefix}\n</PREFIX>\n\n<SUFFIX>\n${details.suffix}\n</SUFFIX>\n\nComplete:"
|
||||
)
|
||||
|
||||
return OpenAIChatCompletionRequest.Builder(listOf(systemMessage, userMessage))
|
||||
.setModel(
|
||||
ModelSelectionService.getInstance().getModelForFeature(FeatureType.CODE_COMPLETION)
|
||||
)
|
||||
.setStream(true)
|
||||
.setMaxTokens(MAX_TOKENS)
|
||||
.setTemperature(0.0)
|
||||
.build()
|
||||
}
|
||||
|
||||
@JvmStatic
|
||||
fun buildChatBasedFIMHttpRequest(
|
||||
details: InfillRequest,
|
||||
url: String,
|
||||
headers: Map<String, String>,
|
||||
body: Map<String, Any>,
|
||||
credential: String?
|
||||
): Request {
|
||||
val requestBuilder = Request.Builder().url(url)
|
||||
|
||||
for (entry in headers.entries) {
|
||||
var value = entry.value
|
||||
if (credential != null && value.contains("\$CUSTOM_SERVICE_API_KEY")) {
|
||||
value = value.replace("\$CUSTOM_SERVICE_API_KEY", credential)
|
||||
}
|
||||
requestBuilder.addHeader(entry.key, value)
|
||||
}
|
||||
|
||||
// Create chat completion messages using the improved prompt template
|
||||
val systemMessage = mapOf<String, String>(
|
||||
"role" to "system",
|
||||
"content" to "You are a code completion assistant. Complete the code between the given prefix and suffix. " +
|
||||
"Return only the missing code that should be inserted, without any formatting, explanations, or markdown."
|
||||
)
|
||||
|
||||
val userMessage = mapOf<String, String>(
|
||||
"role" to "user",
|
||||
"content" to "<PREFIX>\n${details.prefix}\n</PREFIX>\n\n<SUFFIX>\n${details.suffix}\n</SUFFIX>\n\nComplete:"
|
||||
)
|
||||
|
||||
// Transform the custom body configuration, excluding completion-specific parameters
|
||||
val transformedBody = body.entries.mapNotNull { (key, value) ->
|
||||
when (key.lowercase()) {
|
||||
"messages" -> key to listOf(systemMessage, userMessage)
|
||||
// Exclude completion-specific parameters that don't apply to chat completions
|
||||
"prompt", "suffix" -> null
|
||||
else -> key to transformValue(value, InfillPromptTemplate.CHAT_COMPLETION, details)
|
||||
}
|
||||
}.toMap().toMutableMap()
|
||||
|
||||
// Ensure we have messages for chat completion
|
||||
if (!transformedBody.containsKey("messages")) {
|
||||
transformedBody["messages"] = listOf(systemMessage, userMessage)
|
||||
}
|
||||
|
||||
try {
|
||||
val jsonBody = ObjectMapper()
|
||||
.writerWithDefaultPrettyPrinter()
|
||||
.writeValueAsString(transformedBody)
|
||||
.toByteArray(StandardCharsets.UTF_8)
|
||||
.toRequestBody("application/json".toMediaType())
|
||||
return requestBuilder.post(jsonBody).build()
|
||||
} catch (e: JsonProcessingException) {
|
||||
throw RuntimeException(e)
|
||||
}
|
||||
}
|
||||
|
||||
@JvmStatic
|
||||
fun buildCustomRequest(details: InfillRequest): Request {
|
||||
val activeService = service<CustomServicesSettings>()
|
||||
|
|
@ -78,6 +160,12 @@ object CodeCompletionRequestFactory {
|
|||
infillTemplate: InfillPromptTemplate,
|
||||
credential: String?
|
||||
): Request {
|
||||
// For chat-based FIM, we should not use this method
|
||||
// The routing logic in CodeCompletionService will handle it
|
||||
if (infillTemplate == InfillPromptTemplate.CHAT_COMPLETION) {
|
||||
throw IllegalArgumentException("Chat-based FIM should use buildChatBasedFIMRequest instead")
|
||||
}
|
||||
|
||||
val requestBuilder = Request.Builder().url(url)
|
||||
for (entry in headers.entries) {
|
||||
var value = entry.value
|
||||
|
|
|
|||
|
|
@ -3,11 +3,15 @@ package ee.carlrobert.codegpt.codecompletions
|
|||
import com.intellij.openapi.components.Service
|
||||
import com.intellij.openapi.components.service
|
||||
import com.intellij.openapi.project.Project
|
||||
import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestFactory.buildChatBasedFIMRequest
|
||||
import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestFactory.buildChatBasedFIMHttpRequest
|
||||
import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestFactory.buildCustomRequest
|
||||
import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestFactory.buildLlamaRequest
|
||||
import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestFactory.buildOllamaRequest
|
||||
import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestFactory.buildOpenAIRequest
|
||||
import ee.carlrobert.codegpt.completions.CompletionClientProvider
|
||||
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey
|
||||
import ee.carlrobert.codegpt.credentials.CredentialsStore.getCredential
|
||||
import ee.carlrobert.codegpt.settings.service.FeatureType
|
||||
import ee.carlrobert.codegpt.settings.service.ModelSelectionService
|
||||
import ee.carlrobert.codegpt.settings.service.ServiceType
|
||||
|
|
@ -55,23 +59,47 @@ class CodeCompletionService(private val project: Project) {
|
|||
): EventSource {
|
||||
return when (val selectedService =
|
||||
ModelSelectionService.getInstance().getServiceForFeature(FeatureType.CODE_COMPLETION)) {
|
||||
OPENAI -> CompletionClientProvider.getOpenAIClient()
|
||||
.getCompletionAsync(buildOpenAIRequest(infillRequest), eventListener)
|
||||
OPENAI -> {
|
||||
val openAISettings = OpenAISettings.getCurrentState()
|
||||
// Check if user wants to use chat-based FIM (we'll add this setting later)
|
||||
// For now, default to traditional completion
|
||||
CompletionClientProvider.getOpenAIClient()
|
||||
.getCompletionAsync(buildOpenAIRequest(infillRequest), eventListener)
|
||||
}
|
||||
|
||||
CUSTOM_OPENAI -> {
|
||||
val settings = service<CustomServicesSettings>()
|
||||
.customServiceStateForFeatureType(FeatureType.CODE_COMPLETION)
|
||||
.codeCompletionSettings
|
||||
createFactory(
|
||||
CompletionClientProvider.getDefaultClientBuilder().build()
|
||||
).newEventSource(
|
||||
buildCustomRequest(infillRequest),
|
||||
if (settings.parseResponseAsChatCompletions) {
|
||||
val activeService = service<CustomServicesSettings>().state.active
|
||||
val customSettings = activeService.codeCompletionSettings
|
||||
val isChatBasedFIM = customSettings.infillTemplate == InfillPromptTemplate.CHAT_COMPLETION
|
||||
|
||||
if (isChatBasedFIM) {
|
||||
// Use chat completion endpoint for chat-based FIM with proper API key substitution
|
||||
val credential = getCredential(CredentialKey.CustomServiceApiKey(activeService.name.orEmpty()))
|
||||
createFactory(
|
||||
CompletionClientProvider.getDefaultClientBuilder().build()
|
||||
).newEventSource(
|
||||
buildChatBasedFIMHttpRequest(
|
||||
infillRequest,
|
||||
customSettings.url!!,
|
||||
customSettings.headers,
|
||||
customSettings.body,
|
||||
credential
|
||||
),
|
||||
OpenAIChatCompletionEventSourceListener(eventListener)
|
||||
} else {
|
||||
OpenAITextCompletionEventSourceListener(eventListener)
|
||||
}
|
||||
)
|
||||
)
|
||||
} else {
|
||||
// Use traditional completion endpoint
|
||||
createFactory(
|
||||
CompletionClientProvider.getDefaultClientBuilder().build()
|
||||
).newEventSource(
|
||||
buildCustomRequest(infillRequest),
|
||||
if (customSettings.parseResponseAsChatCompletions) {
|
||||
OpenAIChatCompletionEventSourceListener(eventListener)
|
||||
} else {
|
||||
OpenAITextCompletionEventSourceListener(eventListener)
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
MISTRAL -> CompletionClientProvider.getMistralClient()
|
||||
|
|
|
|||
|
|
@ -155,6 +155,13 @@ enum class InfillPromptTemplate(val label: String, val stopTokens: List<String>?
|
|||
"[SUFFIX]${infillDetails.suffix}[PREFIX]${infillDetails.prefix}[MIDDLE]"
|
||||
return createDefaultMultiFilePrompt(infillDetails, infillPrompt)
|
||||
}
|
||||
},
|
||||
CHAT_COMPLETION("Chat-based FIM", listOf("\n\n", "```")) {
|
||||
override fun buildPrompt(infillDetails: InfillRequest): String {
|
||||
// This template is used for chat-based FIM completion
|
||||
// The actual prompt construction is handled in the request factory
|
||||
return "CHAT_FIM_PLACEHOLDER"
|
||||
}
|
||||
};
|
||||
|
||||
abstract fun buildPrompt(infillDetails: InfillRequest): String
|
||||
|
|
|
|||
|
|
@ -162,17 +162,35 @@ class CustomServiceCodeCompletionForm(
|
|||
}
|
||||
|
||||
private fun testConnection() {
|
||||
CompletionRequestService.getInstance().getCustomOpenAICompletionAsync(
|
||||
CodeCompletionRequestFactory.buildCustomRequest(
|
||||
InfillRequest.Builder("Hello", "!", 0).build(),
|
||||
urlField.text,
|
||||
tabbedPane.headers,
|
||||
tabbedPane.body,
|
||||
promptTemplateComboBox.selectedItem as InfillPromptTemplate,
|
||||
getApiKey.invoke()
|
||||
),
|
||||
TestConnectionEventListener()
|
||||
)
|
||||
val selectedTemplate = promptTemplateComboBox.selectedItem as InfillPromptTemplate
|
||||
val testRequest = InfillRequest.Builder("Hello", "!", 0).build()
|
||||
|
||||
if (selectedTemplate == InfillPromptTemplate.CHAT_COMPLETION) {
|
||||
// Use chat completion endpoint for testing
|
||||
CompletionRequestService.getInstance().getCustomOpenAIChatCompletionAsync(
|
||||
CodeCompletionRequestFactory.buildChatBasedFIMHttpRequest(
|
||||
testRequest,
|
||||
urlField.text,
|
||||
tabbedPane.headers,
|
||||
tabbedPane.body,
|
||||
getApiKey.invoke()
|
||||
),
|
||||
TestConnectionEventListener()
|
||||
)
|
||||
} else {
|
||||
// Use traditional completion endpoint for testing
|
||||
CompletionRequestService.getInstance().getCustomOpenAICompletionAsync(
|
||||
CodeCompletionRequestFactory.buildCustomRequest(
|
||||
testRequest,
|
||||
urlField.text,
|
||||
tabbedPane.headers,
|
||||
tabbedPane.body,
|
||||
selectedTemplate,
|
||||
getApiKey.invoke()
|
||||
),
|
||||
TestConnectionEventListener()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
internal inner class TestConnectionEventListener : CompletionEventListener<String?> {
|
||||
|
|
@ -215,4 +233,4 @@ class CustomServiceCodeCompletionForm(
|
|||
.setDescription("<html><p>$description</p></html>")
|
||||
.installOn(promptTemplateHelpText)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,99 @@
|
|||
package ee.carlrobert.codegpt.codecompletions
|
||||
|
||||
import ee.carlrobert.codegpt.settings.service.FeatureType
|
||||
import ee.carlrobert.codegpt.settings.service.ModelSelectionService
|
||||
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionStandardMessage
|
||||
import org.junit.jupiter.api.Test
|
||||
import org.junit.jupiter.api.Assertions.*
|
||||
|
||||
class ChatBasedFIMTest {
|
||||
|
||||
@Test
|
||||
fun `test chat completion template is available`() {
|
||||
val templates = InfillPromptTemplate.values()
|
||||
val chatTemplate = templates.find { it == InfillPromptTemplate.CHAT_COMPLETION }
|
||||
|
||||
assertNotNull(chatTemplate)
|
||||
assertEquals("Chat-based FIM", chatTemplate?.label)
|
||||
assertEquals(listOf("\n\n", "```"), chatTemplate?.stopTokens)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `test chat completion template builds placeholder`() {
|
||||
val template = InfillPromptTemplate.CHAT_COMPLETION
|
||||
val infillRequest = InfillRequest.Builder("function test() {", "}", 0).build()
|
||||
|
||||
val result = template.buildPrompt(infillRequest)
|
||||
assertEquals("CHAT_FIM_PLACEHOLDER", result)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `test chat based FIM request creation`() {
|
||||
val infillRequest = InfillRequest.Builder(
|
||||
"function calculateSum(a, b) {",
|
||||
"return result;\n}",
|
||||
0
|
||||
).build()
|
||||
|
||||
val chatRequest = CodeCompletionRequestFactory.buildChatBasedFIMRequest(infillRequest)
|
||||
|
||||
assertNotNull(chatRequest)
|
||||
assertEquals(2, chatRequest.messages.size)
|
||||
|
||||
val systemMessage = chatRequest.messages[0] as OpenAIChatCompletionStandardMessage
|
||||
assertEquals("system", systemMessage.role)
|
||||
assertTrue(systemMessage.content.contains("expert coding assistant"))
|
||||
|
||||
val userMessage = chatRequest.messages[1] as OpenAIChatCompletionStandardMessage
|
||||
assertEquals("user", userMessage.role)
|
||||
assertTrue(userMessage.content.contains("PREFIX:"))
|
||||
assertTrue(userMessage.content.contains("SUFFIX:"))
|
||||
assertTrue(userMessage.content.contains("function calculateSum(a, b) {"))
|
||||
assertTrue(userMessage.content.contains("return result;"))
|
||||
|
||||
assertTrue(chatRequest.isStream)
|
||||
assertEquals(128, chatRequest.maxTokens)
|
||||
assertEquals(0.0, chatRequest.temperature)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `test custom request throws exception for chat completion template`() {
|
||||
val infillRequest = InfillRequest.Builder("test", "test", 0).build()
|
||||
|
||||
assertThrows(IllegalArgumentException::class.java) {
|
||||
CodeCompletionRequestFactory.buildCustomRequest(
|
||||
infillRequest,
|
||||
"http://test.com",
|
||||
emptyMap(),
|
||||
emptyMap(),
|
||||
InfillPromptTemplate.CHAT_COMPLETION,
|
||||
null
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `test chat based FIM HTTP request creation`() {
|
||||
val infillRequest = InfillRequest.Builder("test prefix", "test suffix", 0).build()
|
||||
val headers = mapOf("Authorization" to "Bearer \$CUSTOM_SERVICE_API_KEY")
|
||||
val body = mapOf<String, Any>(
|
||||
"model" to "gpt-4.1",
|
||||
"temperature" to 0.2,
|
||||
"max_tokens" to 24,
|
||||
"stream" to false
|
||||
)
|
||||
|
||||
val httpRequest = CodeCompletionRequestFactory.buildChatBasedFIMHttpRequest(
|
||||
infillRequest,
|
||||
"https://api.openai.com/v1/chat/completions",
|
||||
headers,
|
||||
body,
|
||||
"test-api-key"
|
||||
)
|
||||
|
||||
assertNotNull(httpRequest)
|
||||
assertEquals("https://api.openai.com/v1/chat/completions", httpRequest.url.toString())
|
||||
assertEquals("Bearer test-api-key", httpRequest.header("Authorization"))
|
||||
assertEquals("application/json", httpRequest.body?.contentType()?.toString())
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue