mirror of
https://github.com/carlrobertoh/ProxyAI.git
synced 2026-05-20 09:24:08 +00:00
feat: set maxTokens and temperature for Ollama chat and code completion
This commit is contained in:
parent
b71cabdca6
commit
56cdc6d76b
4 changed files with 61 additions and 0 deletions
|
|
@ -49,6 +49,7 @@ import ee.carlrobert.llm.client.google.models.GoogleModel;
|
|||
import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest;
|
||||
import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionMessage;
|
||||
import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionRequest;
|
||||
import ee.carlrobert.llm.client.ollama.completion.request.OllamaParameters;
|
||||
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel;
|
||||
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionDetailedMessage;
|
||||
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionMessage;
|
||||
|
|
@ -336,9 +337,15 @@ public class CompletionRequestProvider {
|
|||
public OllamaChatCompletionRequest buildOllamaChatCompletionRequest(
|
||||
CallParameters callParameters
|
||||
) {
|
||||
var configuration = ConfigurationSettings.getCurrentState();
|
||||
var settings = ApplicationManager.getApplication().getService(OllamaSettings.class).getState();
|
||||
return new OllamaChatCompletionRequest
|
||||
.Builder(settings.getModel(), buildOllamaMessages(callParameters))
|
||||
.setStream(true)
|
||||
.setOptions(new OllamaParameters.Builder()
|
||||
.numPredict(configuration.getMaxTokens())
|
||||
.temperature(configuration.getTemperature())
|
||||
.build())
|
||||
.build();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -116,6 +116,7 @@ object CodeCompletionRequestFactory {
|
|||
OllamaParameters.Builder()
|
||||
.stop(settings.fimTemplate.stopTokens)
|
||||
.numPredict(getMaxTokens(details.prefix, details.suffix))
|
||||
.temperature(0.4)
|
||||
.build()
|
||||
)
|
||||
.setRaw(true)
|
||||
|
|
|
|||
|
|
@ -6,8 +6,14 @@ import ee.carlrobert.codegpt.conversations.ConversationService
|
|||
import ee.carlrobert.codegpt.conversations.message.Message
|
||||
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
|
||||
import ee.carlrobert.llm.client.http.RequestEntity
|
||||
import ee.carlrobert.llm.client.http.exchange.NdJsonStreamHttpExchange
|
||||
import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange
|
||||
import ee.carlrobert.llm.client.ollama.OllamaClient
|
||||
import ee.carlrobert.llm.client.ollama.completion.request.OllamaCompletionRequest
|
||||
import ee.carlrobert.llm.client.ollama.completion.request.OllamaParameters
|
||||
import ee.carlrobert.llm.client.util.JSONUtil.*
|
||||
import ee.carlrobert.llm.completion.CompletionEventListener
|
||||
import okhttp3.sse.EventSource
|
||||
import org.apache.http.HttpHeaders
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import testsupport.IntegrationTest
|
||||
|
|
@ -168,6 +174,42 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() {
|
|||
waitExpecting { "Hello!" == message.response }
|
||||
}
|
||||
|
||||
/**
 * Checks that a default chat call routed to Ollama hits `/api/chat` with the bearer
 * credential, and that the request body carries the configured model, the system and
 * user messages, the `num_predict` limit and the streaming flag. The three streamed
 * chunks are then expected to be joined into the message response.
 */
fun testOllamaChatCompletionCall() {
    useOllamaService()
    val configuration = ConfigurationSettings.getCurrentState()
    configuration.maxTokens = 99
    configuration.systemPrompt = "TEST_SYSTEM_PROMPT"

    val prompt = Message("TEST_PROMPT")
    val conversation = ConversationService.getInstance().startConversation()
    val handler = CompletionRequestHandler(getRequestEventListener(prompt))
    val expectedMessages = listOf(
        mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
        mapOf("role" to "user", "content" to "TEST_PROMPT")
    )

    expectOllama(NdJsonStreamHttpExchange { request: RequestEntity ->
        assertThat(request.uri.path).isEqualTo("/api/chat")
        assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY")
        assertThat(request.body)
            .extracting("model", "messages", "options.num_predict", "stream")
            .containsExactly(HuggingFaceModel.LLAMA_3_8B_Q6_K.code, expectedMessages, 99, true)

        // Each chunk is one NDJSON line carrying a fragment of the assistant reply.
        listOf("Hel", "lo", "!").map { fragment ->
            jsonMapResponse("message", jsonMap(e("content", fragment), e("role", "assistant")))
        }
    })

    handler.call(CallParameters(conversation, ConversationType.DEFAULT, prompt, false))

    waitExpecting { prompt.response == "Hello!" }
}
|
||||
|
||||
fun testGoogleChatCompletionCall() {
|
||||
useGoogleService()
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package testsupport.mixin
|
|||
|
||||
import com.intellij.openapi.components.service
|
||||
import com.intellij.testFramework.PlatformTestUtil
|
||||
import ee.carlrobert.codegpt.completions.HuggingFaceModel
|
||||
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.*
|
||||
import ee.carlrobert.codegpt.credentials.CredentialsStore.setCredential
|
||||
import ee.carlrobert.codegpt.settings.GeneralSettings
|
||||
|
|
@ -10,6 +11,7 @@ import ee.carlrobert.codegpt.settings.service.azure.AzureSettings
|
|||
import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTServiceSettings
|
||||
import ee.carlrobert.codegpt.settings.service.google.GoogleSettings
|
||||
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings
|
||||
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings
|
||||
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings
|
||||
import ee.carlrobert.llm.client.google.models.GoogleModel
|
||||
import java.util.function.BooleanSupplier
|
||||
|
|
@ -51,6 +53,15 @@ interface ShortcutsTestMixin {
|
|||
LlamaSettings.getCurrentState().isCodeCompletionsEnabled = codeCompletionsEnabled
|
||||
}
|
||||
|
||||
/**
 * Switches the selected service to Ollama for a test: stores the test API key,
 * sets the model to the `HuggingFaceModel.LLAMA_3_8B_Q6_K` code and clears the host.
 */
fun useOllamaService() {
    GeneralSettings.getCurrentState().selectedService = ServiceType.OLLAMA
    setCredential(OLLAMA_API_KEY, "TEST_API_KEY")

    val ollamaState = service<OllamaSettings>().state
    ollamaState.model = HuggingFaceModel.LLAMA_3_8B_Q6_K.code
    ollamaState.host = null
}
|
||||
|
||||
fun useGoogleService() {
|
||||
GeneralSettings.getCurrentState().selectedService = ServiceType.GOOGLE
|
||||
setCredential(GOOGLE_API_KEY, "TEST_API_KEY")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue