chore: oai compatibility for ollama chat completions

This commit is contained in:
Carl-Robert Linnupuu 2024-11-04 12:06:43 +00:00
parent 38038653b2
commit 4998e91a29
4 changed files with 45 additions and 125 deletions

View file

@ -12,7 +12,7 @@ jsoup = "1.17.2"
jtokkit = "1.1.0"
junit = "5.11.0"
kotlin = "2.0.0"
llm-client = "0.8.25"
llm-client = "0.8.26"
okio = "3.9.0"
tree-sitter = "0.22.6a"

View file

@ -16,9 +16,10 @@ import ee.carlrobert.codegpt.settings.service.google.GoogleSettings;
import ee.carlrobert.llm.client.DeserializationUtil;
import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionRequest;
import ee.carlrobert.llm.client.codegpt.request.chat.ChatCompletionRequest;
import ee.carlrobert.llm.client.codegpt.response.CodeGPTException;
import ee.carlrobert.llm.client.google.completion.GoogleCompletionRequest;
import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest;
import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionRequest;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionEventSourceListener;
import ee.carlrobert.llm.client.openai.completion.OpenAITextCompletionEventSourceListener;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest;
@ -108,6 +109,8 @@ public final class CompletionRequestService {
}
case AZURE -> CompletionClientProvider.getAzureClient()
.getChatCompletionAsync(completionRequest, eventListener);
case OLLAMA -> CompletionClientProvider.getOllamaClient()
.getChatCompletionAsync(completionRequest, eventListener);
default -> throw new RuntimeException("Unknown service selected");
};
}
@ -134,11 +137,6 @@ public final class CompletionRequestService {
.getModel(),
eventListener);
}
if (request instanceof OllamaChatCompletionRequest completionRequest) {
return CompletionClientProvider.getOllamaClient().getChatCompletionAsync(
completionRequest,
eventListener);
}
if (request instanceof LlamaCompletionRequest completionRequest) {
return CompletionClientProvider.getLlamaClient().getChatCompletionAsync(
completionRequest,
@ -156,8 +154,14 @@ public final class CompletionRequestService {
@Override
public void run(@NotNull ProgressIndicator indicator) {
indicator.setIndeterminate(true);
var response = CompletionRequestService.getInstance().getChatCompletion(request);
SwingUtilities.invokeLater(() -> eventListener.onComplete(new StringBuilder(response)));
try {
var response = CompletionRequestService.getInstance().getChatCompletion(request);
SwingUtilities.invokeLater(
() -> eventListener.onComplete(new StringBuilder(response)));
} catch (CodeGPTException e) {
SwingUtilities.invokeLater(
() -> eventListener.onError(new ErrorDetails(e.getDetail()), e));
}
}
});
@ -181,6 +185,8 @@ public final class CompletionRequestService {
.getChatCompletion(completionRequest);
case AZURE -> CompletionClientProvider.getAzureClient()
.getChatCompletion(completionRequest);
case OLLAMA -> CompletionClientProvider.getOllamaClient()
.getChatCompletion(completionRequest);
default -> throw new RuntimeException("Unknown service selected");
};
return tryExtractContent(response).orElseThrow();
@ -217,12 +223,6 @@ public final class CompletionRequestService {
.getContent().getParts().get(0)
.getText();
}
if (request instanceof OllamaChatCompletionRequest completionRequest) {
return CompletionClientProvider.getOllamaClient()
.getChatCompletion(completionRequest)
.getMessage()
.getContent();
}
if (request instanceof LlamaCompletionRequest completionRequest) {
return CompletionClientProvider.getLlamaClient()
.getChatCompletion(completionRequest)

View file

@ -3,36 +3,26 @@ package ee.carlrobert.codegpt.completions.factory
import com.intellij.openapi.components.service
import ee.carlrobert.codegpt.completions.BaseRequestFactory
import ee.carlrobert.codegpt.completions.ChatCompletionParameters
import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT
import ee.carlrobert.codegpt.completions.ConversationType
import ee.carlrobert.codegpt.completions.factory.OpenAIRequestFactory.Companion.buildOpenAIMessages
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
import ee.carlrobert.codegpt.settings.persona.PersonaSettings
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings
import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionMessage
import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionRequest
import ee.carlrobert.llm.client.ollama.completion.request.OllamaParameters
import java.io.IOException
import java.nio.file.Files
import java.nio.file.Path
import java.util.*
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest
import ee.carlrobert.llm.completion.CompletionRequest
class OllamaRequestFactory : BaseRequestFactory() {
override fun createChatRequest(params: ChatCompletionParameters): OllamaChatCompletionRequest {
override fun createChatRequest(params: ChatCompletionParameters): OpenAIChatCompletionRequest {
val model = service<OllamaSettings>().state.model
val configuration = service<ConfigurationSettings>().state
val settings = service<OllamaSettings>().state
return OllamaChatCompletionRequest.Builder(
settings.model,
buildOllamaMessages(params)
)
.setStream(true)
.setOptions(
OllamaParameters.Builder()
.numPredict(configuration.maxTokens)
.temperature(configuration.temperature.toDouble())
.build()
val requestBuilder: OpenAIChatCompletionRequest.Builder =
OpenAIChatCompletionRequest.Builder(
buildOpenAIMessages(model, params, params.referencedFiles)
)
.build()
.setModel(model)
.setMaxTokens(configuration.maxTokens)
.setStream(true)
.setTemperature(configuration.temperature.toDouble())
return requestBuilder.build()
}
override fun createBasicCompletionRequest(
@ -40,71 +30,13 @@ class OllamaRequestFactory : BaseRequestFactory() {
userPrompt: String,
maxTokens: Int,
stream: Boolean
): OllamaChatCompletionRequest {
return OllamaChatCompletionRequest.Builder(
service<OllamaSettings>().state.model,
listOf(
OllamaChatCompletionMessage("system", systemPrompt, null),
OllamaChatCompletionMessage("user", userPrompt, null)
)
): CompletionRequest {
val model = service<OllamaSettings>().state.model
return OpenAIRequestFactory.createBasicCompletionRequest(
systemPrompt,
userPrompt,
model = model,
isStream = stream
)
.setStream(stream)
.build()
}
private fun buildOllamaMessages(params: ChatCompletionParameters): List<OllamaChatCompletionMessage> {
val message = params.message
val messages = mutableListOf<OllamaChatCompletionMessage>()
when (params.conversationType) {
ConversationType.DEFAULT -> messages.add(
OllamaChatCompletionMessage("system", PersonaSettings.getSystemPrompt(), null)
)
ConversationType.FIX_COMPILE_ERRORS -> messages.add(
OllamaChatCompletionMessage("system", FIX_COMPILE_ERRORS_SYSTEM_PROMPT, null)
)
else -> {}
}
for (prevMessage in params.conversation.messages) {
if (params.retry && prevMessage.id == message.id) break
prevMessage.imageFilePath?.takeIf { it.isNotEmpty() }?.let { imagePath ->
try {
val imageBytes = Files.readAllBytes(Path.of(imagePath))
val imageBase64 = Base64.getEncoder().encodeToString(imageBytes)
messages.add(
OllamaChatCompletionMessage(
"user",
prevMessage.prompt,
listOf(imageBase64)
)
)
} catch (e: IOException) {
throw RuntimeException(e)
}
} ?: run {
messages.add(
OllamaChatCompletionMessage(
"user",
getPromptWithFilesContext(params),
null
)
)
}
messages.add(OllamaChatCompletionMessage("assistant", prevMessage.response, null))
}
if (params.imageMediaType != null && params.imageData != null) {
val imageBase64 = Base64.getEncoder().encodeToString(params.imageData)
messages.add(OllamaChatCompletionMessage("user", message.prompt, listOf(imageBase64)))
} else {
messages.add(OllamaChatCompletionMessage("user", message.prompt, null))
}
return messages
}
}
}

View file

@ -148,41 +148,29 @@ class DefaultToolwindowChatCompletionRequestHandlerTest : IntegrationTest() {
val message = Message("TEST_PROMPT")
val conversation = ConversationService.getInstance().startConversation()
expectOllama(NdJsonStreamHttpExchange { request: RequestEntity ->
assertThat(request.uri.path).isEqualTo("/api/chat")
assertThat(request.uri.path).isEqualTo("/v1/chat/completions")
assertThat(request.method).isEqualTo("POST")
assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY")
assertThat(request.body)
.extracting(
"model",
"messages",
"options.num_predict",
"stream"
"messages"
)
.containsExactly(
HuggingFaceModel.LLAMA_3_8B_Q6_K.code,
listOf(
mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
mapOf("role" to "user", "content" to "TEST_PROMPT")
),
99,
true
)
)
listOf(
jsonMapResponse(
e("message", jsonMap(e("content", "Hel"), e("role", "assistant"))),
e("done", false)
"choices",
jsonArray(jsonMap("delta", jsonMap("role", "assistant")))
),
jsonMapResponse(
e("message", jsonMap(e("content", "lo"), e("role", "assistant"))),
e("done", false)
),
jsonMapResponse(
e("message", jsonMap(e("content", "!"), e("role", "assistant"))),
e("done", false)
),
jsonMapResponse(
e("message", jsonMap(e("content", ""), e("role", "assistant"))),
e("done", true)
)
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))
)
})
val requestHandler =