refactor: clean up completion request factory

This commit is contained in:
Carl-Robert Linnupuu 2024-09-24 00:23:27 +03:00
parent 88c54cfa07
commit f26d15fc49
11 changed files with 57 additions and 63 deletions

View file

@@ -66,7 +66,7 @@ public final class CompletionRequestService {
public String getLookupCompletion(String prompt) {
return getChatCompletion(
CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService())
.createLookupCompletionRequest(prompt));
.createLookupRequest(prompt));
}
public EventSource getCommitMessageAsync(
@@ -75,7 +75,7 @@
CompletionEventListener<String> eventListener) {
return getChatCompletionAsync(
CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService())
.createCommitMessageCompletionRequest(systemPrompt, gitDiff),
.createCommitMessageRequest(systemPrompt, gitDiff),
eventListener);
}
@@ -85,7 +85,7 @@
var input = "%s\n\n%s".formatted(params.getPrompt(), params.getSelectedText());
return getChatCompletionAsync(
CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService())
.createEditCodeCompletionRequest(input),
.createEditCodeRequest(input),
eventListener);
}
@@ -94,7 +94,7 @@
CompletionEventListener<String> eventListener) {
return getChatCompletionAsync(
CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService())
.createChatCompletionRequest(callParameters),
.createChatRequest(callParameters),
eventListener);
}

View file

@@ -7,13 +7,10 @@ import ee.carlrobert.codegpt.settings.service.ServiceType
import ee.carlrobert.llm.completion.CompletionRequest
interface CompletionRequestFactory {
fun createChatCompletionRequest(callParameters: CallParameters): CompletionRequest
fun createEditCodeCompletionRequest(input: String): CompletionRequest
fun createCommitMessageCompletionRequest(
systemPrompt: String,
gitDiff: String
): CompletionRequest
fun createLookupCompletionRequest(prompt: String): CompletionRequest
fun createChatRequest(callParameters: CallParameters): CompletionRequest
fun createEditCodeRequest(input: String): CompletionRequest
fun createCommitMessageRequest(systemPrompt: String, gitDiff: String): CompletionRequest
fun createLookupRequest(prompt: String): CompletionRequest
companion object {
@JvmStatic
@@ -33,18 +30,18 @@ interface CompletionRequestFactory {
}
abstract class BaseRequestFactory : CompletionRequestFactory {
override fun createEditCodeCompletionRequest(input: String): CompletionRequest {
override fun createEditCodeRequest(input: String): CompletionRequest {
return createBasicCompletionRequest(EDIT_CODE_SYSTEM_PROMPT, input, true)
}
override fun createCommitMessageCompletionRequest(
override fun createCommitMessageRequest(
systemPrompt: String,
gitDiff: String
): CompletionRequest {
return createBasicCompletionRequest(systemPrompt, gitDiff, true)
}
override fun createLookupCompletionRequest(prompt: String): CompletionRequest {
override fun createLookupRequest(prompt: String): CompletionRequest {
return createBasicCompletionRequest(GENERATE_METHOD_NAMES_SYSTEM_PROMPT, prompt)
}

View file

@@ -10,7 +10,7 @@ import ee.carlrobert.llm.completion.CompletionRequest
class AzureRequestFactory : BaseRequestFactory() {
override fun createChatCompletionRequest(callParameters: CallParameters): OpenAIChatCompletionRequest {
override fun createChatRequest(callParameters: CallParameters): OpenAIChatCompletionRequest {
val configuration = service<ConfigurationSettings>().state
val requestBuilder: OpenAIChatCompletionRequest.Builder =
OpenAIChatCompletionRequest.Builder(buildOpenAIMessages(null, callParameters))

View file

@@ -11,7 +11,7 @@ import ee.carlrobert.llm.completion.CompletionRequest
class ClaudeRequestFactory : BaseRequestFactory() {
override fun createChatCompletionRequest(callParameters: CallParameters): ClaudeCompletionRequest {
override fun createChatRequest(callParameters: CallParameters): ClaudeCompletionRequest {
return ClaudeCompletionRequest().apply {
model = service<AnthropicSettings>().state.model
maxTokens = service<ConfigurationSettings>().state.maxTokens

View file

@@ -11,7 +11,7 @@ import ee.carlrobert.llm.client.openai.completion.request.RequestDocumentationDe
class CodeGPTRequestFactory : BaseRequestFactory() {
override fun createChatCompletionRequest(callParameters: CallParameters): OpenAIChatCompletionRequest {
override fun createChatRequest(callParameters: CallParameters): OpenAIChatCompletionRequest {
val model = service<CodeGPTServiceSettings>().state.chatCompletionSettings.model
val configuration = service<ConfigurationSettings>().state
val requestBuilder: OpenAIChatCompletionRequest.Builder =

View file

@@ -19,7 +19,7 @@ class CustomOpenAIRequest(val request: Request) : CompletionRequest
class CustomOpenAIRequestFactory : BaseRequestFactory() {
override fun createChatCompletionRequest(callParameters: CallParameters): CustomOpenAIRequest {
override fun createChatRequest(callParameters: CallParameters): CustomOpenAIRequest {
val request = buildCustomOpenAIChatCompletionRequest(
service<CustomServiceSettings>()
.state

View file

@@ -23,7 +23,7 @@ import java.nio.file.Path
class GoogleRequestFactory : BaseRequestFactory() {
override fun createChatCompletionRequest(callParameters: CallParameters): GoogleCompletionRequest {
override fun createChatRequest(callParameters: CallParameters): GoogleCompletionRequest {
val configuration = service<ConfigurationSettings>().state
val messages = buildGoogleMessages(service<GoogleSettings>().state.model, callParameters)
return GoogleCompletionRequest.Builder(messages)

View file

@@ -6,46 +6,28 @@ import ee.carlrobert.codegpt.completions.CallParameters
import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT
import ee.carlrobert.codegpt.completions.ConversationType
import ee.carlrobert.codegpt.completions.llama.LlamaModel
import ee.carlrobert.codegpt.completions.llama.PromptTemplate
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings.Companion.getState
import ee.carlrobert.codegpt.settings.persona.PersonaSettings.Companion.getSystemPrompt
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings
import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest
class LlamaRequestFactory : BaseRequestFactory() {
override fun createChatCompletionRequest(callParameters: CallParameters): LlamaCompletionRequest {
val settings = service<LlamaSettings>().state
val promptTemplate = if (settings.isRunLocalServer) {
if (settings.isUseCustomModel)
settings.localModelPromptTemplate
else
LlamaModel.findByHuggingFaceModel(settings.huggingFaceModel).promptTemplate
} else {
settings.remoteModelPromptTemplate
}
override fun createChatRequest(callParameters: CallParameters): LlamaCompletionRequest {
val promptTemplate = getPromptTemplate()
val systemPrompt =
if (callParameters.conversationType == ConversationType.FIX_COMPILE_ERRORS)
FIX_COMPILE_ERRORS_SYSTEM_PROMPT
else
getSystemPrompt()
val prompt = promptTemplate.buildPrompt(
systemPrompt,
callParameters.message.prompt,
callParameters.conversation.messages
)
val configuration = getState()
return LlamaCompletionRequest.Builder(prompt)
.setN_predict(configuration.maxTokens)
.setTemperature(configuration.temperature.toDouble())
.setTop_k(settings.topK)
.setTop_p(settings.topP)
.setMin_p(settings.minP)
.setRepeat_penalty(settings.repeatPenalty)
.setStop(promptTemplate.stopTokens)
.build()
return buildLlamaRequest(prompt, promptTemplate.stopTokens, true)
}
override fun createBasicCompletionRequest(
@@ -53,8 +35,16 @@ class LlamaRequestFactory : BaseRequestFactory() {
userPrompt: String,
stream: Boolean
): LlamaCompletionRequest {
val promptTemplate = getPromptTemplate()
val finalPrompt =
promptTemplate.buildPrompt(systemPrompt, userPrompt, listOf())
return buildLlamaRequest(finalPrompt, emptyList(), stream)
}
private fun getPromptTemplate(): PromptTemplate {
val settings = service<LlamaSettings>().state
val promptTemplate = if (settings.isRunLocalServer) {
return if (settings.isRunLocalServer) {
if (settings.isUseCustomModel)
settings.localModelPromptTemplate
else
@@ -62,17 +52,24 @@ class LlamaRequestFactory : BaseRequestFactory() {
} else {
settings.remoteModelPromptTemplate
}
val configuration = service<ConfigurationSettings>().state
val finalPrompt =
promptTemplate.buildPrompt(systemPrompt, userPrompt, listOf())
return LlamaCompletionRequest.Builder(finalPrompt)
.setN_predict(configuration.maxTokens)
.setTemperature(configuration.temperature.toDouble())
.setTop_k(settings.topK)
.setTop_p(settings.topP)
.setMin_p(settings.minP)
}
private fun buildLlamaRequest(
prompt: String,
stopTokens: List<String>,
stream: Boolean = false
): LlamaCompletionRequest {
val configSettings = service<ConfigurationSettings>().state
val llamaSettings = service<LlamaSettings>().state
return LlamaCompletionRequest.Builder(prompt)
.setN_predict(configSettings.maxTokens)
.setTemperature(configSettings.temperature.toDouble())
.setTop_k(llamaSettings.topK)
.setTop_p(llamaSettings.topP)
.setMin_p(llamaSettings.minP)
.setRepeat_penalty(llamaSettings.repeatPenalty)
.setStop(stopTokens)
.setStream(stream)
.setRepeat_penalty(settings.repeatPenalty)
.build()
}
}

View file

@@ -18,7 +18,7 @@ import java.util.*
class OllamaRequestFactory : BaseRequestFactory() {
override fun createChatCompletionRequest(callParameters: CallParameters): OllamaChatCompletionRequest {
override fun createChatRequest(callParameters: CallParameters): OllamaChatCompletionRequest {
val configuration = service<ConfigurationSettings>().state
val settings = service<OllamaSettings>().state
return OllamaChatCompletionRequest.Builder(

View file

@@ -24,7 +24,7 @@ import java.nio.file.Path
class OpenAIRequestFactory : CompletionRequestFactory {
override fun createChatCompletionRequest(callParameters: CallParameters): OpenAIChatCompletionRequest {
override fun createChatRequest(callParameters: CallParameters): OpenAIChatCompletionRequest {
val model = service<OpenAISettings>().state.model
val configuration = service<ConfigurationSettings>().state
val requestBuilder: OpenAIChatCompletionRequest.Builder =
@@ -36,11 +36,11 @@ class OpenAIRequestFactory : CompletionRequestFactory {
return requestBuilder.build()
}
override fun createEditCodeCompletionRequest(input: String): OpenAIChatCompletionRequest {
override fun createEditCodeRequest(input: String): OpenAIChatCompletionRequest {
return buildEditCodeRequest(input, service<OpenAISettings>().state.model)
}
override fun createCommitMessageCompletionRequest(
override fun createCommitMessageRequest(
systemPrompt: String,
gitDiff: String
): CompletionRequest {
@@ -52,7 +52,7 @@ class OpenAIRequestFactory : CompletionRequestFactory {
)
}
override fun createLookupCompletionRequest(prompt: String): CompletionRequest {
override fun createLookupRequest(prompt: String): CompletionRequest {
return createBasicCompletionRequest(
GENERATE_METHOD_NAMES_SYSTEM_PROMPT,
prompt,

View file

@@ -23,7 +23,7 @@ class CompletionRequestProviderTest : IntegrationTest() {
conversation.addMessage(firstMessage)
conversation.addMessage(secondMessage)
val request = OpenAIRequestFactory().createChatCompletionRequest(
val request = OpenAIRequestFactory().createChatRequest(
CallParameters(
conversation,
ConversationType.DEFAULT,
@@ -54,7 +54,7 @@ class CompletionRequestProviderTest : IntegrationTest() {
conversation.addMessage(firstMessage)
conversation.addMessage(secondMessage)
val request = OpenAIRequestFactory().createChatCompletionRequest(
val request = OpenAIRequestFactory().createChatRequest(
CallParameters(
conversation,
ConversationType.DEFAULT,
@@ -85,7 +85,7 @@ class CompletionRequestProviderTest : IntegrationTest() {
conversation.addMessage(firstMessage)
conversation.addMessage(secondMessage)
val request = OpenAIRequestFactory().createChatCompletionRequest(
val request = OpenAIRequestFactory().createChatRequest(
CallParameters(
conversation,
ConversationType.DEFAULT,
@@ -117,7 +117,7 @@ class CompletionRequestProviderTest : IntegrationTest() {
conversation.addMessage(remainingMessage)
conversation.discardTokenLimits()
val request = OpenAIRequestFactory().createChatCompletionRequest(
val request = OpenAIRequestFactory().createChatRequest(
CallParameters(
conversation,
ConversationType.DEFAULT,
@@ -145,7 +145,7 @@ class CompletionRequestProviderTest : IntegrationTest() {
conversation.addMessage(createDummyMessage(1500))
assertThrows(TotalUsageExceededException::class.java) {
OpenAIRequestFactory().createChatCompletionRequest(
OpenAIRequestFactory().createChatRequest(
CallParameters(
conversation,
ConversationType.DEFAULT,