feat: set maxTokens and temperature for Ollama chat and code completion

This commit is contained in:
PhilKes 2024-05-25 14:20:08 +02:00 committed by Carl-Robert Linnupuu
parent b71cabdca6
commit 56cdc6d76b
4 changed files with 61 additions and 0 deletions

View file

@ -49,6 +49,7 @@ import ee.carlrobert.llm.client.google.models.GoogleModel;
import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest;
import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionMessage;
import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionRequest;
import ee.carlrobert.llm.client.ollama.completion.request.OllamaParameters;
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionDetailedMessage;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionMessage;
@ -336,9 +337,15 @@ public class CompletionRequestProvider {
/**
 * Builds the streaming Ollama chat completion request for the given call.
 *
 * <p>The model is read from the {@link OllamaSettings} state and the messages come from
 * {@code buildOllamaMessages}. The request options carry the configured max tokens
 * (as {@code numPredict}) and the configured temperature.
 *
 * @param callParameters parameters of the current call, forwarded to {@code buildOllamaMessages}
 * @return the assembled request with streaming enabled
 */
public OllamaChatCompletionRequest buildOllamaChatCompletionRequest(
    CallParameters callParameters
) {
  var configuration = ConfigurationSettings.getCurrentState();
  var ollamaState = ApplicationManager.getApplication()
      .getService(OllamaSettings.class)
      .getState();

  var model = ollamaState.getModel();
  var messages = buildOllamaMessages(callParameters);

  // Generation limits are taken from the shared configuration state.
  var options = new OllamaParameters.Builder()
      .numPredict(configuration.getMaxTokens())
      .temperature(configuration.getTemperature())
      .build();

  return new OllamaChatCompletionRequest.Builder(model, messages)
      .setStream(true)
      .setOptions(options)
      .build();
}

View file

@ -116,6 +116,7 @@ object CodeCompletionRequestFactory {
OllamaParameters.Builder()
.stop(settings.fimTemplate.stopTokens)
.numPredict(getMaxTokens(details.prefix, details.suffix))
.temperature(0.4)
.build()
)
.setRaw(true)

View file

@ -6,8 +6,14 @@ import ee.carlrobert.codegpt.conversations.ConversationService
import ee.carlrobert.codegpt.conversations.message.Message
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
import ee.carlrobert.llm.client.http.RequestEntity
import ee.carlrobert.llm.client.http.exchange.NdJsonStreamHttpExchange
import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange
import ee.carlrobert.llm.client.ollama.OllamaClient
import ee.carlrobert.llm.client.ollama.completion.request.OllamaCompletionRequest
import ee.carlrobert.llm.client.ollama.completion.request.OllamaParameters
import ee.carlrobert.llm.client.util.JSONUtil.*
import ee.carlrobert.llm.completion.CompletionEventListener
import okhttp3.sse.EventSource
import org.apache.http.HttpHeaders
import org.assertj.core.api.Assertions.assertThat
import testsupport.IntegrationTest
@ -168,6 +174,42 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() {
waitExpecting { "Hello!" == message.response }
}
/**
 * Checks the request sent to Ollama's chat endpoint and the assembly of the streamed reply.
 *
 * The mocked exchange asserts the request path, the bearer token header and the
 * `model`, `messages`, `options.num_predict` and `stream` body fields, then answers
 * with three message chunks that must be concatenated into "Hello!".
 */
fun testOllamaChatCompletionCall() {
    useOllamaService()
    val configuration = ConfigurationSettings.getCurrentState()
    configuration.maxTokens = 99
    configuration.systemPrompt = "TEST_SYSTEM_PROMPT"
    val message = Message("TEST_PROMPT")
    val conversation = ConversationService.getInstance().startConversation()
    val requestHandler = CompletionRequestHandler(getRequestEventListener(message))
    // System prompt first, then the user prompt.
    val expectedMessages = listOf(
        mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
        mapOf("role" to "user", "content" to "TEST_PROMPT")
    )

    expectOllama(NdJsonStreamHttpExchange { request: RequestEntity ->
        assertThat(request.uri.path).isEqualTo("/api/chat")
        val authorization = request.headers[HttpHeaders.AUTHORIZATION]!![0]
        assertThat(authorization).isEqualTo("Bearer TEST_API_KEY")
        assertThat(request.body)
            .extracting("model", "messages", "options.num_predict", "stream")
            .containsExactly(HuggingFaceModel.LLAMA_3_8B_Q6_K.code, expectedMessages, 99, true)
        // Reply is streamed as three assistant chunks.
        listOf("Hel", "lo", "!").map { chunk ->
            jsonMapResponse("message", jsonMap(e("content", chunk), e("role", "assistant")))
        }
    })

    requestHandler.call(CallParameters(conversation, ConversationType.DEFAULT, message, false))

    waitExpecting { "Hello!" == message.response }
}
fun testGoogleChatCompletionCall() {
useGoogleService()

View file

@ -2,6 +2,7 @@ package testsupport.mixin
import com.intellij.openapi.components.service
import com.intellij.testFramework.PlatformTestUtil
import ee.carlrobert.codegpt.completions.HuggingFaceModel
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.*
import ee.carlrobert.codegpt.credentials.CredentialsStore.setCredential
import ee.carlrobert.codegpt.settings.GeneralSettings
@ -10,6 +11,7 @@ import ee.carlrobert.codegpt.settings.service.azure.AzureSettings
import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTServiceSettings
import ee.carlrobert.codegpt.settings.service.google.GoogleSettings
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings
import ee.carlrobert.llm.client.google.models.GoogleModel
import java.util.function.BooleanSupplier
@ -51,6 +53,15 @@ interface ShortcutsTestMixin {
LlamaSettings.getCurrentState().isCodeCompletionsEnabled = codeCompletionsEnabled
}
/**
 * Switches the selected service to Ollama for a test.
 *
 * Stores "TEST_API_KEY" as the Ollama credential, sets the model to
 * [HuggingFaceModel.LLAMA_3_8B_Q6_K] and clears the configured host.
 */
fun useOllamaService() {
    GeneralSettings.getCurrentState().selectedService = ServiceType.OLLAMA
    setCredential(OLLAMA_API_KEY, "TEST_API_KEY")
    val ollamaState = service<OllamaSettings>().state
    ollamaState.model = HuggingFaceModel.LLAMA_3_8B_Q6_K.code
    ollamaState.host = null
}
fun useGoogleService() {
GeneralSettings.getCurrentState().selectedService = ServiceType.GOOGLE
setCredential(GOOGLE_API_KEY, "TEST_API_KEY")