From f26d15fc4914c68ceb07c98b117c98dede13744e Mon Sep 17 00:00:00 2001 From: Carl-Robert Linnupuu Date: Tue, 24 Sep 2024 00:23:27 +0300 Subject: [PATCH] refactor: clean up completion request factory --- .../completions/CompletionRequestService.java | 8 +-- .../completions/CompletionRequestFactory.kt | 17 ++--- .../factory/AzureRequestFactory.kt | 2 +- .../factory/ClaudeRequestFactory.kt | 2 +- .../factory/CodeGPTRequestFactory.kt | 2 +- .../factory/CustomOpenAIRequestFactory.kt | 2 +- .../factory/GoogleRequestFactory.kt | 2 +- .../factory/LlamaRequestFactory.kt | 65 +++++++++---------- .../factory/OllamaRequestFactory.kt | 2 +- .../factory/OpenAIRequestFactory.kt | 8 +-- .../CompletionRequestProviderTest.kt | 10 +-- 11 files changed, 57 insertions(+), 63 deletions(-) diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java index f3573209..3ba8e44a 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java @@ -66,7 +66,7 @@ public final class CompletionRequestService { public String getLookupCompletion(String prompt) { return getChatCompletion( CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService()) - .createLookupCompletionRequest(prompt)); + .createLookupRequest(prompt)); } public EventSource getCommitMessageAsync( @@ -75,7 +75,7 @@ public final class CompletionRequestService { CompletionEventListener eventListener) { return getChatCompletionAsync( CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService()) - .createCommitMessageCompletionRequest(systemPrompt, gitDiff), + .createCommitMessageRequest(systemPrompt, gitDiff), eventListener); } @@ -85,7 +85,7 @@ public final class CompletionRequestService { var input = "%s\n\n%s".formatted(params.getPrompt(), params.getSelectedText()); return getChatCompletionAsync( CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService()) - .createEditCodeCompletionRequest(input), + .createEditCodeRequest(input), eventListener); } @@ -94,7 +94,7 @@ public final class CompletionRequestService { CompletionEventListener eventListener) { return getChatCompletionAsync( CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService()) - .createChatCompletionRequest(callParameters), + .createChatRequest(callParameters), eventListener); } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestFactory.kt index e8ed1136..3aa6a112 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestFactory.kt @@ -7,13 +7,10 @@ import ee.carlrobert.codegpt.settings.service.ServiceType import ee.carlrobert.llm.completion.CompletionRequest interface CompletionRequestFactory { - fun createChatCompletionRequest(callParameters: CallParameters): CompletionRequest - fun createEditCodeCompletionRequest(input: String): CompletionRequest - fun createCommitMessageCompletionRequest( - systemPrompt: String, - gitDiff: String - ): CompletionRequest - fun createLookupCompletionRequest(prompt: String): CompletionRequest + fun createChatRequest(callParameters: CallParameters): CompletionRequest + fun createEditCodeRequest(input: String): CompletionRequest + fun createCommitMessageRequest(systemPrompt: String, gitDiff: String): CompletionRequest + fun createLookupRequest(prompt: String): CompletionRequest companion object { @JvmStatic @@ -33,18 +30,18 @@ interface CompletionRequestFactory { } abstract class BaseRequestFactory : CompletionRequestFactory { - override fun createEditCodeCompletionRequest(input: String): CompletionRequest { + override fun createEditCodeRequest(input: String): CompletionRequest { return createBasicCompletionRequest(EDIT_CODE_SYSTEM_PROMPT, input, true) } - override fun createCommitMessageCompletionRequest( + override fun createCommitMessageRequest( systemPrompt: String, gitDiff: String ): CompletionRequest { return createBasicCompletionRequest(systemPrompt, gitDiff, true) } - override fun createLookupCompletionRequest(prompt: String): CompletionRequest { + override fun createLookupRequest(prompt: String): CompletionRequest { return createBasicCompletionRequest(GENERATE_METHOD_NAMES_SYSTEM_PROMPT, prompt) } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/AzureRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/AzureRequestFactory.kt index c7a18499..201ad51a 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/AzureRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/AzureRequestFactory.kt @@ -10,7 +10,7 @@ import ee.carlrobert.llm.completion.CompletionRequest class AzureRequestFactory : BaseRequestFactory() { - override fun createChatCompletionRequest(callParameters: CallParameters): OpenAIChatCompletionRequest { + override fun createChatRequest(callParameters: CallParameters): OpenAIChatCompletionRequest { val configuration = service().state val requestBuilder: OpenAIChatCompletionRequest.Builder = OpenAIChatCompletionRequest.Builder(buildOpenAIMessages(null, callParameters)) diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/ClaudeRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/ClaudeRequestFactory.kt index 66ee4b12..e4fbe81b 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/ClaudeRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/ClaudeRequestFactory.kt @@ -11,7 +11,7 @@ import ee.carlrobert.llm.completion.CompletionRequest class ClaudeRequestFactory : BaseRequestFactory() { - override fun createChatCompletionRequest(callParameters: CallParameters): ClaudeCompletionRequest { + override fun createChatRequest(callParameters: CallParameters): ClaudeCompletionRequest { return ClaudeCompletionRequest().apply { model = service().state.model maxTokens = service().state.maxTokens diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CodeGPTRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CodeGPTRequestFactory.kt index 7c6a4f84..d53efa56 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CodeGPTRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CodeGPTRequestFactory.kt @@ -11,7 +11,7 @@ import ee.carlrobert.llm.client.openai.completion.request.RequestDocumentationDe class CodeGPTRequestFactory : BaseRequestFactory() { - override fun createChatCompletionRequest(callParameters: CallParameters): OpenAIChatCompletionRequest { + override fun createChatRequest(callParameters: CallParameters): OpenAIChatCompletionRequest { val model = service().state.chatCompletionSettings.model val configuration = service().state val requestBuilder: OpenAIChatCompletionRequest.Builder = diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CustomOpenAIRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CustomOpenAIRequestFactory.kt index 41418282..a118572e 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CustomOpenAIRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/CustomOpenAIRequestFactory.kt @@ -19,7 +19,7 @@ class CustomOpenAIRequest(val request: Request) : CompletionRequest class CustomOpenAIRequestFactory : BaseRequestFactory() { - override fun createChatCompletionRequest(callParameters: CallParameters): CustomOpenAIRequest { + override fun createChatRequest(callParameters: CallParameters): CustomOpenAIRequest { val request = buildCustomOpenAIChatCompletionRequest( service() .state diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/GoogleRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/GoogleRequestFactory.kt index f770de60..be66b669 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/GoogleRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/GoogleRequestFactory.kt @@ -23,7 +23,7 @@ import java.nio.file.Path class GoogleRequestFactory : BaseRequestFactory() { - override fun createChatCompletionRequest(callParameters: CallParameters): GoogleCompletionRequest { + override fun createChatRequest(callParameters: CallParameters): GoogleCompletionRequest { val configuration = service().state val messages = buildGoogleMessages(service().state.model, callParameters) return GoogleCompletionRequest.Builder(messages) diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/LlamaRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/LlamaRequestFactory.kt index 747f32dd..0ed0a2e0 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/LlamaRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/LlamaRequestFactory.kt @@ -6,46 +6,28 @@ import ee.carlrobert.codegpt.completions.CallParameters import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT import ee.carlrobert.codegpt.completions.ConversationType import ee.carlrobert.codegpt.completions.llama.LlamaModel +import ee.carlrobert.codegpt.completions.llama.PromptTemplate import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings -import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings.Companion.getState import ee.carlrobert.codegpt.settings.persona.PersonaSettings.Companion.getSystemPrompt import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest class LlamaRequestFactory : BaseRequestFactory() { - override fun createChatCompletionRequest(callParameters: CallParameters): LlamaCompletionRequest { - val settings = service().state - val promptTemplate = if (settings.isRunLocalServer) { - if (settings.isUseCustomModel) - settings.localModelPromptTemplate - else - LlamaModel.findByHuggingFaceModel(settings.huggingFaceModel).promptTemplate - } else { - settings.remoteModelPromptTemplate - } - + override fun createChatRequest(callParameters: CallParameters): LlamaCompletionRequest { + val promptTemplate = getPromptTemplate() val systemPrompt = if (callParameters.conversationType == ConversationType.FIX_COMPILE_ERRORS) FIX_COMPILE_ERRORS_SYSTEM_PROMPT else getSystemPrompt() - val prompt = promptTemplate.buildPrompt( systemPrompt, callParameters.message.prompt, callParameters.conversation.messages ) - val configuration = getState() - return LlamaCompletionRequest.Builder(prompt) - .setN_predict(configuration.maxTokens) - .setTemperature(configuration.temperature.toDouble()) - .setTop_k(settings.topK) - .setTop_p(settings.topP) - .setMin_p(settings.minP) - .setRepeat_penalty(settings.repeatPenalty) - .setStop(promptTemplate.stopTokens) - .build() + + return buildLlamaRequest(prompt, promptTemplate.stopTokens, true) } override fun createBasicCompletionRequest( @@ -53,8 +35,16 @@ class LlamaRequestFactory : BaseRequestFactory() { userPrompt: String, stream: Boolean ): LlamaCompletionRequest { + val promptTemplate = getPromptTemplate() + val finalPrompt = + promptTemplate.buildPrompt(systemPrompt, userPrompt, listOf()) + + return buildLlamaRequest(finalPrompt, emptyList(), stream) + } + + private fun getPromptTemplate(): PromptTemplate { val settings = service().state - val promptTemplate = if (settings.isRunLocalServer) { + return if (settings.isRunLocalServer) { if (settings.isUseCustomModel) settings.localModelPromptTemplate else @@ -62,17 +52,24 @@ class LlamaRequestFactory : BaseRequestFactory() { } else { settings.remoteModelPromptTemplate } - val configuration = service().state - val finalPrompt = - promptTemplate.buildPrompt(systemPrompt, userPrompt, listOf()) - return LlamaCompletionRequest.Builder(finalPrompt) - .setN_predict(configuration.maxTokens) - .setTemperature(configuration.temperature.toDouble()) - .setTop_k(settings.topK) - .setTop_p(settings.topP) - .setMin_p(settings.minP) + } + + private fun buildLlamaRequest( + prompt: String, + stopTokens: List, + stream: Boolean = false + ): LlamaCompletionRequest { + val configSettings = service().state + val llamaSettings = service().state + return LlamaCompletionRequest.Builder(prompt) + .setN_predict(configSettings.maxTokens) + .setTemperature(configSettings.temperature.toDouble()) + .setTop_k(llamaSettings.topK) + .setTop_p(llamaSettings.topP) + .setMin_p(llamaSettings.minP) + .setRepeat_penalty(llamaSettings.repeatPenalty) + .setStop(stopTokens) .setStream(stream) - .setRepeat_penalty(settings.repeatPenalty) .build() } } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OllamaRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OllamaRequestFactory.kt index 80fc9954..58942c06 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OllamaRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OllamaRequestFactory.kt @@ -18,7 +18,7 @@ import java.util.* class OllamaRequestFactory : BaseRequestFactory() { - override fun createChatCompletionRequest(callParameters: CallParameters): OllamaChatCompletionRequest { + override fun createChatRequest(callParameters: CallParameters): OllamaChatCompletionRequest { val configuration = service().state val settings = service().state return OllamaChatCompletionRequest.Builder( diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OpenAIRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OpenAIRequestFactory.kt index 33bcc479..f235f179 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OpenAIRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OpenAIRequestFactory.kt @@ -24,7 +24,7 @@ import java.nio.file.Path class OpenAIRequestFactory : CompletionRequestFactory { - override fun createChatCompletionRequest(callParameters: CallParameters): OpenAIChatCompletionRequest { + override fun createChatRequest(callParameters: CallParameters): OpenAIChatCompletionRequest { val model = service().state.model val configuration = service().state val requestBuilder: OpenAIChatCompletionRequest.Builder = @@ -36,11 +36,11 @@ class OpenAIRequestFactory : CompletionRequestFactory { return requestBuilder.build() } - override fun createEditCodeCompletionRequest(input: String): OpenAIChatCompletionRequest { + override fun createEditCodeRequest(input: String): OpenAIChatCompletionRequest { return buildEditCodeRequest(input, service().state.model) } - override fun createCommitMessageCompletionRequest( + override fun createCommitMessageRequest( systemPrompt: String, gitDiff: String ): CompletionRequest { @@ -52,7 +52,7 @@ class OpenAIRequestFactory : CompletionRequestFactory { ) } - override fun createLookupCompletionRequest(prompt: String): CompletionRequest { + override fun createLookupRequest(prompt: String): CompletionRequest { return createBasicCompletionRequest( GENERATE_METHOD_NAMES_SYSTEM_PROMPT, prompt, diff --git a/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt index d6051cc1..5b6a8d94 100644 --- a/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt +++ b/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt @@ -23,7 +23,7 @@ class CompletionRequestProviderTest : IntegrationTest() { conversation.addMessage(firstMessage) conversation.addMessage(secondMessage) - val request = OpenAIRequestFactory().createChatCompletionRequest( + val request = OpenAIRequestFactory().createChatRequest( CallParameters( conversation, ConversationType.DEFAULT, @@ -54,7 +54,7 @@ class CompletionRequestProviderTest : IntegrationTest() { conversation.addMessage(firstMessage) conversation.addMessage(secondMessage) - val request = OpenAIRequestFactory().createChatCompletionRequest( + val request = OpenAIRequestFactory().createChatRequest( CallParameters( conversation, ConversationType.DEFAULT, @@ -85,7 +85,7 @@ class CompletionRequestProviderTest : IntegrationTest() { conversation.addMessage(firstMessage) conversation.addMessage(secondMessage) - val request = OpenAIRequestFactory().createChatCompletionRequest( + val request = OpenAIRequestFactory().createChatRequest( CallParameters( conversation, ConversationType.DEFAULT, @@ -117,7 +117,7 @@ class CompletionRequestProviderTest : IntegrationTest() { conversation.addMessage(remainingMessage) conversation.discardTokenLimits() - val request = OpenAIRequestFactory().createChatCompletionRequest( + val request = OpenAIRequestFactory().createChatRequest( CallParameters( conversation, ConversationType.DEFAULT, @@ -145,7 +145,7 @@ class CompletionRequestProviderTest : IntegrationTest() { conversation.addMessage(createDummyMessage(1500)) assertThrows(TotalUsageExceededException::class.java) { - OpenAIRequestFactory().createChatCompletionRequest( + OpenAIRequestFactory().createChatRequest( CallParameters( conversation, ConversationType.DEFAULT,