From 8de72b330178fa97b0482e1dbb2829964f8b737a Mon Sep 17 00:00:00 2001
From: Phil
Date: Thu, 25 Apr 2024 15:47:56 +0200
Subject: [PATCH] fix: use /infill for llama.cpp code-completions (#513)

---
 .../codegpt/completions/CompletionRequestService.java      | 2 +-
 .../codecompletions/CodeCompletionRequestFactory.kt        | 9 ++++-----
 .../codegpt/codecompletions/CodeCompletionServiceTest.kt   | 6 +++---
 3 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java
index edacf655..c2c63831 100644
--- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java
+++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java
@@ -122,7 +122,7 @@ public final class CompletionRequestService {
               CodeCompletionRequestFactory.buildCustomRequest(requestDetails),
               new OpenAITextCompletionEventSourceListener(eventListener));
       case LLAMA_CPP -> CompletionClientProvider.getLlamaClient()
-          .getChatCompletionAsync(
+          .getInfillAsync(
               CodeCompletionRequestFactory.buildLlamaRequest(requestDetails),
               eventListener);
       default ->
diff --git a/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionRequestFactory.kt
index d7723f0d..203a3add 100644
--- a/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionRequestFactory.kt
+++ b/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionRequestFactory.kt
@@ -12,6 +12,7 @@ import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings
 import ee.carlrobert.codegpt.settings.service.llama.LlamaSettingsState
 import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings
 import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest
+import ee.carlrobert.llm.client.llama.completion.LlamaInfillRequest
 import ee.carlrobert.llm.client.openai.completion.request.OpenAITextCompletionRequest
 import okhttp3.MediaType.Companion.toMediaType
 import okhttp3.Request
@@ -59,16 +60,14 @@ object CodeCompletionRequestFactory {
     }
 
     @JvmStatic
-    fun buildLlamaRequest(details: InfillRequestDetails): LlamaCompletionRequest {
+    fun buildLlamaRequest(details: InfillRequestDetails): LlamaInfillRequest {
         val settings = LlamaSettings.getCurrentState()
         val promptTemplate = getLlamaInfillPromptTemplate(settings)
-        val prompt = promptTemplate.buildPrompt(details.prefix, details.suffix)
-        return LlamaCompletionRequest.Builder(prompt)
+        return LlamaInfillRequest(LlamaCompletionRequest.Builder(null)
             .setN_predict(settings.codeCompletionMaxTokens)
             .setStream(true)
             .setTemperature(0.4)
-            .setStop(promptTemplate.stopTokens)
-            .build()
+            .setStop(promptTemplate.stopTokens), details.prefix, details.suffix)
     }
 
     private fun getLlamaInfillPromptTemplate(settings: LlamaSettingsState): InfillPromptTemplate {
diff --git a/src/test/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionServiceTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionServiceTest.kt
index d92c00b9..5c2aeeed 100644
--- a/src/test/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionServiceTest.kt
+++ b/src/test/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionServiceTest.kt
@@ -35,11 +35,11 @@ class CodeCompletionServiceTest : IntegrationTest() {
       ${"z".repeat(247)}
      """.trimIndent() // 128 tokens
    expectLlama(StreamHttpExchange { request: RequestEntity ->
-      assertThat(request.uri.path).isEqualTo("/completion")
+      assertThat(request.uri.path).isEqualTo("/infill")
       assertThat(request.method).isEqualTo("POST")
       assertThat(request.body)
-        .extracting("prompt")
-        .isEqualTo(InfillPromptTemplate.LLAMA.buildPrompt(prefix, suffix))
+        .extracting("input_prefix", "input_suffix")
+        .containsExactly(prefix, suffix)
       listOf(jsonMapResponse(e("content", expectedCompletion), e("stop", true)))
    })
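
Note (not part of the patch): a minimal sketch of the request this change is expected to produce. Instead of rendering a single prompt client-side via InfillPromptTemplate, the prefix and suffix are now sent as separate fields to the llama.cpp server's /infill endpoint, which applies the model's fill-in-the-middle formatting itself. The field names below mirror the updated test's extracting("input_prefix", "input_suffix"); the example values and the exact wire format are assumptions and depend on how the llm-client library serializes LlamaInfillRequest.

// Hypothetical illustration only; the prefix/suffix values are made up and the
// InfillRequestDetails constructor shape is assumed, not taken from the patch.
val details = InfillRequestDetails(
    "fun sum(a: Int, b: Int): Int {\n    return ",  // prefix: code before the caret
    "\n}"                                            // suffix: code after the caret
)
val request = CodeCompletionRequestFactory.buildLlamaRequest(details)

// Roughly the resulting HTTP call:
// POST /infill
// {
//   "input_prefix": "fun sum(a: Int, b: Int): Int {\n    return ",
//   "input_suffix": "\n}",
//   "n_predict": <settings.codeCompletionMaxTokens>,
//   "stream": true,
//   "temperature": 0.4,
//   "stop": <promptTemplate.stopTokens>
// }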