From 4998e91a293c2ef501043f8499afd19f5a3f2300 Mon Sep 17 00:00:00 2001 From: Carl-Robert Linnupuu Date: Mon, 4 Nov 2024 12:06:43 +0000 Subject: [PATCH] chore: oai compatibility for ollama chat completions --- gradle/libs.versions.toml | 2 +- .../completions/CompletionRequestService.java | 28 ++--- .../factory/OllamaRequestFactory.kt | 110 ++++-------------- ...lwindowChatCompletionRequestHandlerTest.kt | 30 ++--- 4 files changed, 45 insertions(+), 125 deletions(-) diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 8e8d9f95..2f0aa137 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -12,7 +12,7 @@ jsoup = "1.17.2" jtokkit = "1.1.0" junit = "5.11.0" kotlin = "2.0.0" -llm-client = "0.8.25" +llm-client = "0.8.26" okio = "3.9.0" tree-sitter = "0.22.6a" diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java index d9464884..ed494da2 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java @@ -16,9 +16,10 @@ import ee.carlrobert.codegpt.settings.service.google.GoogleSettings; import ee.carlrobert.llm.client.DeserializationUtil; import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionRequest; import ee.carlrobert.llm.client.codegpt.request.chat.ChatCompletionRequest; +import ee.carlrobert.llm.client.codegpt.response.CodeGPTException; import ee.carlrobert.llm.client.google.completion.GoogleCompletionRequest; import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest; -import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionRequest; +import ee.carlrobert.llm.client.openai.completion.ErrorDetails; import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionEventSourceListener; import ee.carlrobert.llm.client.openai.completion.OpenAITextCompletionEventSourceListener; import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest; @@ -108,6 +109,8 @@ public final class CompletionRequestService { } case AZURE -> CompletionClientProvider.getAzureClient() .getChatCompletionAsync(completionRequest, eventListener); + case OLLAMA -> CompletionClientProvider.getOllamaClient() + .getChatCompletionAsync(completionRequest, eventListener); default -> throw new RuntimeException("Unknown service selected"); }; } @@ -134,11 +137,6 @@ public final class CompletionRequestService { .getModel(), eventListener); } - if (request instanceof OllamaChatCompletionRequest completionRequest) { - return CompletionClientProvider.getOllamaClient().getChatCompletionAsync( - completionRequest, - eventListener); - } if (request instanceof LlamaCompletionRequest completionRequest) { return CompletionClientProvider.getLlamaClient().getChatCompletionAsync( completionRequest, @@ -156,8 +154,14 @@ public final class CompletionRequestService { @Override public void run(@NotNull ProgressIndicator indicator) { indicator.setIndeterminate(true); - var response = CompletionRequestService.getInstance().getChatCompletion(request); - SwingUtilities.invokeLater(() -> eventListener.onComplete(new StringBuilder(response))); + try { + var response = CompletionRequestService.getInstance().getChatCompletion(request); + SwingUtilities.invokeLater( + () -> eventListener.onComplete(new StringBuilder(response))); + } catch (CodeGPTException e) { + SwingUtilities.invokeLater( + () -> eventListener.onError(new ErrorDetails(e.getDetail()), e)); + } } }); @@ -181,6 +185,8 @@ public final class CompletionRequestService { .getChatCompletion(completionRequest); case AZURE -> CompletionClientProvider.getAzureClient() .getChatCompletion(completionRequest); + case OLLAMA -> CompletionClientProvider.getOllamaClient() + .getChatCompletion(completionRequest); default -> throw new RuntimeException("Unknown service selected"); }; return tryExtractContent(response).orElseThrow(); @@ -217,12 +223,6 @@ public final class CompletionRequestService { .getContent().getParts().get(0) .getText(); } - if (request instanceof OllamaChatCompletionRequest completionRequest) { - return CompletionClientProvider.getOllamaClient() - .getChatCompletion(completionRequest) - .getMessage() - .getContent(); - } if (request instanceof LlamaCompletionRequest completionRequest) { return CompletionClientProvider.getLlamaClient() .getChatCompletion(completionRequest) diff --git a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OllamaRequestFactory.kt b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OllamaRequestFactory.kt index 2914f595..27451715 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OllamaRequestFactory.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/completions/factory/OllamaRequestFactory.kt @@ -3,36 +3,26 @@ package ee.carlrobert.codegpt.completions.factory import com.intellij.openapi.components.service import ee.carlrobert.codegpt.completions.BaseRequestFactory import ee.carlrobert.codegpt.completions.ChatCompletionParameters -import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT -import ee.carlrobert.codegpt.completions.ConversationType +import ee.carlrobert.codegpt.completions.factory.OpenAIRequestFactory.Companion.buildOpenAIMessages import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings -import ee.carlrobert.codegpt.settings.persona.PersonaSettings import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings -import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionMessage -import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionRequest -import ee.carlrobert.llm.client.ollama.completion.request.OllamaParameters -import java.io.IOException -import java.nio.file.Files -import java.nio.file.Path -import java.util.* +import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest +import ee.carlrobert.llm.completion.CompletionRequest class OllamaRequestFactory : BaseRequestFactory() { - override fun createChatRequest(params: ChatCompletionParameters): OllamaChatCompletionRequest { + override fun createChatRequest(params: ChatCompletionParameters): OpenAIChatCompletionRequest { + val model = service().state.model val configuration = service().state - val settings = service().state - return OllamaChatCompletionRequest.Builder( - settings.model, - buildOllamaMessages(params) - ) - .setStream(true) - .setOptions( - OllamaParameters.Builder() - .numPredict(configuration.maxTokens) - .temperature(configuration.temperature.toDouble()) - .build() + val requestBuilder: OpenAIChatCompletionRequest.Builder = + OpenAIChatCompletionRequest.Builder( + buildOpenAIMessages(model, params, params.referencedFiles) ) - .build() + .setModel(model) + .setMaxTokens(configuration.maxTokens) + .setStream(true) + .setTemperature(configuration.temperature.toDouble()) + return requestBuilder.build() } override fun createBasicCompletionRequest( @@ -40,71 +30,13 @@ class OllamaRequestFactory : BaseRequestFactory() { userPrompt: String, maxTokens: Int, stream: Boolean - ): OllamaChatCompletionRequest { - return OllamaChatCompletionRequest.Builder( - service().state.model, - listOf( - OllamaChatCompletionMessage("system", systemPrompt, null), - OllamaChatCompletionMessage("user", userPrompt, null) - ) + ): CompletionRequest { + val model = service().state.model + return OpenAIRequestFactory.createBasicCompletionRequest( + systemPrompt, + userPrompt, + model = model, + isStream = stream ) - .setStream(stream) - .build() } - - private fun buildOllamaMessages(params: ChatCompletionParameters): List { - val message = params.message - val messages = mutableListOf() - - when (params.conversationType) { - ConversationType.DEFAULT -> messages.add( - OllamaChatCompletionMessage("system", PersonaSettings.getSystemPrompt(), null) - ) - - ConversationType.FIX_COMPILE_ERRORS -> messages.add( - OllamaChatCompletionMessage("system", FIX_COMPILE_ERRORS_SYSTEM_PROMPT, null) - ) - - else -> {} - } - - for (prevMessage in params.conversation.messages) { - if (params.retry && prevMessage.id == message.id) break - - prevMessage.imageFilePath?.takeIf { it.isNotEmpty() }?.let { imagePath -> - try { - val imageBytes = Files.readAllBytes(Path.of(imagePath)) - val imageBase64 = Base64.getEncoder().encodeToString(imageBytes) - messages.add( - OllamaChatCompletionMessage( - "user", - prevMessage.prompt, - listOf(imageBase64) - ) - ) - } catch (e: IOException) { - throw RuntimeException(e) - } - } ?: run { - messages.add( - OllamaChatCompletionMessage( - "user", - getPromptWithFilesContext(params), - null - ) - ) - } - - messages.add(OllamaChatCompletionMessage("assistant", prevMessage.response, null)) - } - - if (params.imageMediaType != null && params.imageData != null) { - val imageBase64 = Base64.getEncoder().encodeToString(params.imageData) - messages.add(OllamaChatCompletionMessage("user", message.prompt, listOf(imageBase64))) - } else { - messages.add(OllamaChatCompletionMessage("user", message.prompt, null)) - } - - return messages - } -} \ No newline at end of file +} diff --git a/src/test/kotlin/ee/carlrobert/codegpt/completions/DefaultToolwindowChatCompletionRequestHandlerTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/completions/DefaultToolwindowChatCompletionRequestHandlerTest.kt index 3fc6bc64..9c52ab7b 100644 --- a/src/test/kotlin/ee/carlrobert/codegpt/completions/DefaultToolwindowChatCompletionRequestHandlerTest.kt +++ b/src/test/kotlin/ee/carlrobert/codegpt/completions/DefaultToolwindowChatCompletionRequestHandlerTest.kt @@ -148,41 +148,29 @@ class DefaultToolwindowChatCompletionRequestHandlerTest : IntegrationTest() { val message = Message("TEST_PROMPT") val conversation = ConversationService.getInstance().startConversation() expectOllama(NdJsonStreamHttpExchange { request: RequestEntity -> - assertThat(request.uri.path).isEqualTo("/api/chat") + assertThat(request.uri.path).isEqualTo("/v1/chat/completions") + assertThat(request.method).isEqualTo("POST") assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY") assertThat(request.body) .extracting( "model", - "messages", - "options.num_predict", - "stream" + "messages" ) .containsExactly( HuggingFaceModel.LLAMA_3_8B_Q6_K.code, listOf( mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"), mapOf("role" to "user", "content" to "TEST_PROMPT") - ), - 99, - true + ) ) listOf( jsonMapResponse( - e("message", jsonMap(e("content", "Hel"), e("role", "assistant"))), - e("done", false) + "choices", + jsonArray(jsonMap("delta", jsonMap("role", "assistant"))) ), - jsonMapResponse( - e("message", jsonMap(e("content", "lo"), e("role", "assistant"))), - e("done", false) - ), - jsonMapResponse( - e("message", jsonMap(e("content", "!"), e("role", "assistant"))), - e("done", false) - ), - jsonMapResponse( - e("message", jsonMap(e("content", ""), e("role", "assistant"))), - e("done", true) - ) + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))) ) }) val requestHandler =