mirror of
https://github.com/carlrobertoh/ProxyAI.git
synced 2026-05-12 05:51:28 +00:00
refactor: clean up completion request factory
This commit is contained in:
parent
88c54cfa07
commit
f26d15fc49
11 changed files with 57 additions and 63 deletions
|
|
@ -66,7 +66,7 @@ public final class CompletionRequestService {
|
|||
public String getLookupCompletion(String prompt) {
|
||||
return getChatCompletion(
|
||||
CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService())
|
||||
.createLookupCompletionRequest(prompt));
|
||||
.createLookupRequest(prompt));
|
||||
}
|
||||
|
||||
public EventSource getCommitMessageAsync(
|
||||
|
|
@ -75,7 +75,7 @@ public final class CompletionRequestService {
|
|||
CompletionEventListener<String> eventListener) {
|
||||
return getChatCompletionAsync(
|
||||
CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService())
|
||||
.createCommitMessageCompletionRequest(systemPrompt, gitDiff),
|
||||
.createCommitMessageRequest(systemPrompt, gitDiff),
|
||||
eventListener);
|
||||
}
|
||||
|
||||
|
|
@ -85,7 +85,7 @@ public final class CompletionRequestService {
|
|||
var input = "%s\n\n%s".formatted(params.getPrompt(), params.getSelectedText());
|
||||
return getChatCompletionAsync(
|
||||
CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService())
|
||||
.createEditCodeCompletionRequest(input),
|
||||
.createEditCodeRequest(input),
|
||||
eventListener);
|
||||
}
|
||||
|
||||
|
|
@ -94,7 +94,7 @@ public final class CompletionRequestService {
|
|||
CompletionEventListener<String> eventListener) {
|
||||
return getChatCompletionAsync(
|
||||
CompletionRequestFactory.getFactory(GeneralSettings.getSelectedService())
|
||||
.createChatCompletionRequest(callParameters),
|
||||
.createChatRequest(callParameters),
|
||||
eventListener);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -7,13 +7,10 @@ import ee.carlrobert.codegpt.settings.service.ServiceType
|
|||
import ee.carlrobert.llm.completion.CompletionRequest
|
||||
|
||||
interface CompletionRequestFactory {
|
||||
fun createChatCompletionRequest(callParameters: CallParameters): CompletionRequest
|
||||
fun createEditCodeCompletionRequest(input: String): CompletionRequest
|
||||
fun createCommitMessageCompletionRequest(
|
||||
systemPrompt: String,
|
||||
gitDiff: String
|
||||
): CompletionRequest
|
||||
fun createLookupCompletionRequest(prompt: String): CompletionRequest
|
||||
fun createChatRequest(callParameters: CallParameters): CompletionRequest
|
||||
fun createEditCodeRequest(input: String): CompletionRequest
|
||||
fun createCommitMessageRequest(systemPrompt: String, gitDiff: String): CompletionRequest
|
||||
fun createLookupRequest(prompt: String): CompletionRequest
|
||||
|
||||
companion object {
|
||||
@JvmStatic
|
||||
|
|
@ -33,18 +30,18 @@ interface CompletionRequestFactory {
|
|||
}
|
||||
|
||||
abstract class BaseRequestFactory : CompletionRequestFactory {
|
||||
override fun createEditCodeCompletionRequest(input: String): CompletionRequest {
|
||||
override fun createEditCodeRequest(input: String): CompletionRequest {
|
||||
return createBasicCompletionRequest(EDIT_CODE_SYSTEM_PROMPT, input, true)
|
||||
}
|
||||
|
||||
override fun createCommitMessageCompletionRequest(
|
||||
override fun createCommitMessageRequest(
|
||||
systemPrompt: String,
|
||||
gitDiff: String
|
||||
): CompletionRequest {
|
||||
return createBasicCompletionRequest(systemPrompt, gitDiff, true)
|
||||
}
|
||||
|
||||
override fun createLookupCompletionRequest(prompt: String): CompletionRequest {
|
||||
override fun createLookupRequest(prompt: String): CompletionRequest {
|
||||
return createBasicCompletionRequest(GENERATE_METHOD_NAMES_SYSTEM_PROMPT, prompt)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import ee.carlrobert.llm.completion.CompletionRequest
|
|||
|
||||
class AzureRequestFactory : BaseRequestFactory() {
|
||||
|
||||
override fun createChatCompletionRequest(callParameters: CallParameters): OpenAIChatCompletionRequest {
|
||||
override fun createChatRequest(callParameters: CallParameters): OpenAIChatCompletionRequest {
|
||||
val configuration = service<ConfigurationSettings>().state
|
||||
val requestBuilder: OpenAIChatCompletionRequest.Builder =
|
||||
OpenAIChatCompletionRequest.Builder(buildOpenAIMessages(null, callParameters))
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import ee.carlrobert.llm.completion.CompletionRequest
|
|||
|
||||
class ClaudeRequestFactory : BaseRequestFactory() {
|
||||
|
||||
override fun createChatCompletionRequest(callParameters: CallParameters): ClaudeCompletionRequest {
|
||||
override fun createChatRequest(callParameters: CallParameters): ClaudeCompletionRequest {
|
||||
return ClaudeCompletionRequest().apply {
|
||||
model = service<AnthropicSettings>().state.model
|
||||
maxTokens = service<ConfigurationSettings>().state.maxTokens
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import ee.carlrobert.llm.client.openai.completion.request.RequestDocumentationDe
|
|||
|
||||
class CodeGPTRequestFactory : BaseRequestFactory() {
|
||||
|
||||
override fun createChatCompletionRequest(callParameters: CallParameters): OpenAIChatCompletionRequest {
|
||||
override fun createChatRequest(callParameters: CallParameters): OpenAIChatCompletionRequest {
|
||||
val model = service<CodeGPTServiceSettings>().state.chatCompletionSettings.model
|
||||
val configuration = service<ConfigurationSettings>().state
|
||||
val requestBuilder: OpenAIChatCompletionRequest.Builder =
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ class CustomOpenAIRequest(val request: Request) : CompletionRequest
|
|||
|
||||
class CustomOpenAIRequestFactory : BaseRequestFactory() {
|
||||
|
||||
override fun createChatCompletionRequest(callParameters: CallParameters): CustomOpenAIRequest {
|
||||
override fun createChatRequest(callParameters: CallParameters): CustomOpenAIRequest {
|
||||
val request = buildCustomOpenAIChatCompletionRequest(
|
||||
service<CustomServiceSettings>()
|
||||
.state
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ import java.nio.file.Path
|
|||
|
||||
class GoogleRequestFactory : BaseRequestFactory() {
|
||||
|
||||
override fun createChatCompletionRequest(callParameters: CallParameters): GoogleCompletionRequest {
|
||||
override fun createChatRequest(callParameters: CallParameters): GoogleCompletionRequest {
|
||||
val configuration = service<ConfigurationSettings>().state
|
||||
val messages = buildGoogleMessages(service<GoogleSettings>().state.model, callParameters)
|
||||
return GoogleCompletionRequest.Builder(messages)
|
||||
|
|
|
|||
|
|
@ -6,46 +6,28 @@ import ee.carlrobert.codegpt.completions.CallParameters
|
|||
import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT
|
||||
import ee.carlrobert.codegpt.completions.ConversationType
|
||||
import ee.carlrobert.codegpt.completions.llama.LlamaModel
|
||||
import ee.carlrobert.codegpt.completions.llama.PromptTemplate
|
||||
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
|
||||
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings.Companion.getState
|
||||
import ee.carlrobert.codegpt.settings.persona.PersonaSettings.Companion.getSystemPrompt
|
||||
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings
|
||||
import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest
|
||||
|
||||
class LlamaRequestFactory : BaseRequestFactory() {
|
||||
|
||||
override fun createChatCompletionRequest(callParameters: CallParameters): LlamaCompletionRequest {
|
||||
val settings = service<LlamaSettings>().state
|
||||
val promptTemplate = if (settings.isRunLocalServer) {
|
||||
if (settings.isUseCustomModel)
|
||||
settings.localModelPromptTemplate
|
||||
else
|
||||
LlamaModel.findByHuggingFaceModel(settings.huggingFaceModel).promptTemplate
|
||||
} else {
|
||||
settings.remoteModelPromptTemplate
|
||||
}
|
||||
|
||||
override fun createChatRequest(callParameters: CallParameters): LlamaCompletionRequest {
|
||||
val promptTemplate = getPromptTemplate()
|
||||
val systemPrompt =
|
||||
if (callParameters.conversationType == ConversationType.FIX_COMPILE_ERRORS)
|
||||
FIX_COMPILE_ERRORS_SYSTEM_PROMPT
|
||||
else
|
||||
getSystemPrompt()
|
||||
|
||||
val prompt = promptTemplate.buildPrompt(
|
||||
systemPrompt,
|
||||
callParameters.message.prompt,
|
||||
callParameters.conversation.messages
|
||||
)
|
||||
val configuration = getState()
|
||||
return LlamaCompletionRequest.Builder(prompt)
|
||||
.setN_predict(configuration.maxTokens)
|
||||
.setTemperature(configuration.temperature.toDouble())
|
||||
.setTop_k(settings.topK)
|
||||
.setTop_p(settings.topP)
|
||||
.setMin_p(settings.minP)
|
||||
.setRepeat_penalty(settings.repeatPenalty)
|
||||
.setStop(promptTemplate.stopTokens)
|
||||
.build()
|
||||
|
||||
return buildLlamaRequest(prompt, promptTemplate.stopTokens, true)
|
||||
}
|
||||
|
||||
override fun createBasicCompletionRequest(
|
||||
|
|
@ -53,8 +35,16 @@ class LlamaRequestFactory : BaseRequestFactory() {
|
|||
userPrompt: String,
|
||||
stream: Boolean
|
||||
): LlamaCompletionRequest {
|
||||
val promptTemplate = getPromptTemplate()
|
||||
val finalPrompt =
|
||||
promptTemplate.buildPrompt(systemPrompt, userPrompt, listOf())
|
||||
|
||||
return buildLlamaRequest(finalPrompt, emptyList(), stream)
|
||||
}
|
||||
|
||||
private fun getPromptTemplate(): PromptTemplate {
|
||||
val settings = service<LlamaSettings>().state
|
||||
val promptTemplate = if (settings.isRunLocalServer) {
|
||||
return if (settings.isRunLocalServer) {
|
||||
if (settings.isUseCustomModel)
|
||||
settings.localModelPromptTemplate
|
||||
else
|
||||
|
|
@ -62,17 +52,24 @@ class LlamaRequestFactory : BaseRequestFactory() {
|
|||
} else {
|
||||
settings.remoteModelPromptTemplate
|
||||
}
|
||||
val configuration = service<ConfigurationSettings>().state
|
||||
val finalPrompt =
|
||||
promptTemplate.buildPrompt(systemPrompt, userPrompt, listOf())
|
||||
return LlamaCompletionRequest.Builder(finalPrompt)
|
||||
.setN_predict(configuration.maxTokens)
|
||||
.setTemperature(configuration.temperature.toDouble())
|
||||
.setTop_k(settings.topK)
|
||||
.setTop_p(settings.topP)
|
||||
.setMin_p(settings.minP)
|
||||
}
|
||||
|
||||
private fun buildLlamaRequest(
|
||||
prompt: String,
|
||||
stopTokens: List<String>,
|
||||
stream: Boolean = false
|
||||
): LlamaCompletionRequest {
|
||||
val configSettings = service<ConfigurationSettings>().state
|
||||
val llamaSettings = service<LlamaSettings>().state
|
||||
return LlamaCompletionRequest.Builder(prompt)
|
||||
.setN_predict(configSettings.maxTokens)
|
||||
.setTemperature(configSettings.temperature.toDouble())
|
||||
.setTop_k(llamaSettings.topK)
|
||||
.setTop_p(llamaSettings.topP)
|
||||
.setMin_p(llamaSettings.minP)
|
||||
.setRepeat_penalty(llamaSettings.repeatPenalty)
|
||||
.setStop(stopTokens)
|
||||
.setStream(stream)
|
||||
.setRepeat_penalty(settings.repeatPenalty)
|
||||
.build()
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import java.util.*
|
|||
|
||||
class OllamaRequestFactory : BaseRequestFactory() {
|
||||
|
||||
override fun createChatCompletionRequest(callParameters: CallParameters): OllamaChatCompletionRequest {
|
||||
override fun createChatRequest(callParameters: CallParameters): OllamaChatCompletionRequest {
|
||||
val configuration = service<ConfigurationSettings>().state
|
||||
val settings = service<OllamaSettings>().state
|
||||
return OllamaChatCompletionRequest.Builder(
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ import java.nio.file.Path
|
|||
|
||||
class OpenAIRequestFactory : CompletionRequestFactory {
|
||||
|
||||
override fun createChatCompletionRequest(callParameters: CallParameters): OpenAIChatCompletionRequest {
|
||||
override fun createChatRequest(callParameters: CallParameters): OpenAIChatCompletionRequest {
|
||||
val model = service<OpenAISettings>().state.model
|
||||
val configuration = service<ConfigurationSettings>().state
|
||||
val requestBuilder: OpenAIChatCompletionRequest.Builder =
|
||||
|
|
@ -36,11 +36,11 @@ class OpenAIRequestFactory : CompletionRequestFactory {
|
|||
return requestBuilder.build()
|
||||
}
|
||||
|
||||
override fun createEditCodeCompletionRequest(input: String): OpenAIChatCompletionRequest {
|
||||
override fun createEditCodeRequest(input: String): OpenAIChatCompletionRequest {
|
||||
return buildEditCodeRequest(input, service<OpenAISettings>().state.model)
|
||||
}
|
||||
|
||||
override fun createCommitMessageCompletionRequest(
|
||||
override fun createCommitMessageRequest(
|
||||
systemPrompt: String,
|
||||
gitDiff: String
|
||||
): CompletionRequest {
|
||||
|
|
@ -52,7 +52,7 @@ class OpenAIRequestFactory : CompletionRequestFactory {
|
|||
)
|
||||
}
|
||||
|
||||
override fun createLookupCompletionRequest(prompt: String): CompletionRequest {
|
||||
override fun createLookupRequest(prompt: String): CompletionRequest {
|
||||
return createBasicCompletionRequest(
|
||||
GENERATE_METHOD_NAMES_SYSTEM_PROMPT,
|
||||
prompt,
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ class CompletionRequestProviderTest : IntegrationTest() {
|
|||
conversation.addMessage(firstMessage)
|
||||
conversation.addMessage(secondMessage)
|
||||
|
||||
val request = OpenAIRequestFactory().createChatCompletionRequest(
|
||||
val request = OpenAIRequestFactory().createChatRequest(
|
||||
CallParameters(
|
||||
conversation,
|
||||
ConversationType.DEFAULT,
|
||||
|
|
@ -54,7 +54,7 @@ class CompletionRequestProviderTest : IntegrationTest() {
|
|||
conversation.addMessage(firstMessage)
|
||||
conversation.addMessage(secondMessage)
|
||||
|
||||
val request = OpenAIRequestFactory().createChatCompletionRequest(
|
||||
val request = OpenAIRequestFactory().createChatRequest(
|
||||
CallParameters(
|
||||
conversation,
|
||||
ConversationType.DEFAULT,
|
||||
|
|
@ -85,7 +85,7 @@ class CompletionRequestProviderTest : IntegrationTest() {
|
|||
conversation.addMessage(firstMessage)
|
||||
conversation.addMessage(secondMessage)
|
||||
|
||||
val request = OpenAIRequestFactory().createChatCompletionRequest(
|
||||
val request = OpenAIRequestFactory().createChatRequest(
|
||||
CallParameters(
|
||||
conversation,
|
||||
ConversationType.DEFAULT,
|
||||
|
|
@ -117,7 +117,7 @@ class CompletionRequestProviderTest : IntegrationTest() {
|
|||
conversation.addMessage(remainingMessage)
|
||||
conversation.discardTokenLimits()
|
||||
|
||||
val request = OpenAIRequestFactory().createChatCompletionRequest(
|
||||
val request = OpenAIRequestFactory().createChatRequest(
|
||||
CallParameters(
|
||||
conversation,
|
||||
ConversationType.DEFAULT,
|
||||
|
|
@ -145,7 +145,7 @@ class CompletionRequestProviderTest : IntegrationTest() {
|
|||
conversation.addMessage(createDummyMessage(1500))
|
||||
|
||||
assertThrows(TotalUsageExceededException::class.java) {
|
||||
OpenAIRequestFactory().createChatCompletionRequest(
|
||||
OpenAIRequestFactory().createChatRequest(
|
||||
CallParameters(
|
||||
conversation,
|
||||
ConversationType.DEFAULT,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue