diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java index a10d99c2..5a3474bd 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java @@ -49,6 +49,7 @@ import ee.carlrobert.llm.client.google.models.GoogleModel; import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest; import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionMessage; import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionRequest; +import ee.carlrobert.llm.client.ollama.completion.request.OllamaParameters; import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel; import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionDetailedMessage; import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionMessage; @@ -336,9 +337,15 @@ public class CompletionRequestProvider { public OllamaChatCompletionRequest buildOllamaChatCompletionRequest( CallParameters callParameters ) { + var configuration = ConfigurationSettings.getCurrentState(); var settings = ApplicationManager.getApplication().getService(OllamaSettings.class).getState(); return new OllamaChatCompletionRequest .Builder(settings.getModel(), buildOllamaMessages(callParameters)) + .setStream(true) + .setOptions(new OllamaParameters.Builder() + .numPredict(configuration.getMaxTokens()) + .temperature(configuration.getTemperature()) + .build()) .build(); } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionRequestFactory.kt index 1b1a777a..75741fcd 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionRequestFactory.kt +++ 
b/src/main/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionRequestFactory.kt @@ -116,6 +116,7 @@ object CodeCompletionRequestFactory { OllamaParameters.Builder() .stop(settings.fimTemplate.stopTokens) .numPredict(getMaxTokens(details.prefix, details.suffix)) + .temperature(0.4) .build() ) .setRaw(true) diff --git a/src/test/kotlin/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.kt index 5e99ec2a..1f7c9f8d 100644 --- a/src/test/kotlin/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.kt +++ b/src/test/kotlin/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.kt @@ -6,8 +6,14 @@ import ee.carlrobert.codegpt.conversations.ConversationService import ee.carlrobert.codegpt.conversations.message.Message import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings import ee.carlrobert.llm.client.http.RequestEntity +import ee.carlrobert.llm.client.http.exchange.NdJsonStreamHttpExchange import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange +import ee.carlrobert.llm.client.ollama.OllamaClient +import ee.carlrobert.llm.client.ollama.completion.request.OllamaCompletionRequest +import ee.carlrobert.llm.client.ollama.completion.request.OllamaParameters import ee.carlrobert.llm.client.util.JSONUtil.* +import ee.carlrobert.llm.completion.CompletionEventListener +import okhttp3.sse.EventSource import org.apache.http.HttpHeaders import org.assertj.core.api.Assertions.assertThat import testsupport.IntegrationTest @@ -168,6 +174,42 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() { waitExpecting { "Hello!" 
== message.response } } + fun testOllamaChatCompletionCall() { + useOllamaService() + ConfigurationSettings.getCurrentState().maxTokens = 99 + ConfigurationSettings.getCurrentState().systemPrompt = "TEST_SYSTEM_PROMPT" + val message = Message("TEST_PROMPT") + val conversation = ConversationService.getInstance().startConversation() + val requestHandler = CompletionRequestHandler(getRequestEventListener(message)) + expectOllama(NdJsonStreamHttpExchange { request: RequestEntity -> + assertThat(request.uri.path).isEqualTo("/api/chat") + assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY") + assertThat(request.body) + .extracting( + "model", + "messages", + "options.num_predict", + "stream" + ) + .containsExactly( + HuggingFaceModel.LLAMA_3_8B_Q6_K.code, + listOf( + mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"), + mapOf("role" to "user", "content" to "TEST_PROMPT") + ), + 99, + true + ) + listOf( + jsonMapResponse("message", jsonMap(e("content", "Hel"), e("role", "assistant"))), + jsonMapResponse("message", jsonMap(e("content", "lo"), e("role", "assistant"))), + jsonMapResponse("message", jsonMap(e("content", "!"), e("role", "assistant"))) + ) + }) + + requestHandler.call(CallParameters(conversation, ConversationType.DEFAULT, message, false)) + waitExpecting { "Hello!" 
== message.response } + } fun testGoogleChatCompletionCall() { useGoogleService() diff --git a/src/test/kotlin/testsupport/mixin/ShortcutsTestMixin.kt b/src/test/kotlin/testsupport/mixin/ShortcutsTestMixin.kt index 1e325475..69a8156c 100644 --- a/src/test/kotlin/testsupport/mixin/ShortcutsTestMixin.kt +++ b/src/test/kotlin/testsupport/mixin/ShortcutsTestMixin.kt @@ -2,6 +2,7 @@ package testsupport.mixin import com.intellij.openapi.components.service import com.intellij.testFramework.PlatformTestUtil +import ee.carlrobert.codegpt.completions.HuggingFaceModel import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.* import ee.carlrobert.codegpt.credentials.CredentialsStore.setCredential import ee.carlrobert.codegpt.settings.GeneralSettings @@ -10,6 +11,7 @@ import ee.carlrobert.codegpt.settings.service.azure.AzureSettings import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTServiceSettings import ee.carlrobert.codegpt.settings.service.google.GoogleSettings import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings +import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings import ee.carlrobert.llm.client.google.models.GoogleModel import java.util.function.BooleanSupplier @@ -51,6 +53,15 @@ interface ShortcutsTestMixin { LlamaSettings.getCurrentState().isCodeCompletionsEnabled = codeCompletionsEnabled } + fun useOllamaService() { + GeneralSettings.getCurrentState().selectedService = ServiceType.OLLAMA + setCredential(OLLAMA_API_KEY, "TEST_API_KEY") + service<OllamaSettings>().state.apply { + model = HuggingFaceModel.LLAMA_3_8B_Q6_K.code + host = null + } + } + fun useGoogleService() { GeneralSettings.getCurrentState().selectedService = ServiceType.GOOGLE setCredential(GOOGLE_API_KEY, "TEST_API_KEY")