mirror of
https://github.com/carlrobertoh/ProxyAI.git
synced 2026-05-20 09:24:08 +00:00
chore: oai compatibility for ollama chat completions
This commit is contained in:
parent
38038653b2
commit
4998e91a29
4 changed files with 45 additions and 125 deletions
|
|
@ -12,7 +12,7 @@ jsoup = "1.17.2"
|
|||
jtokkit = "1.1.0"
|
||||
junit = "5.11.0"
|
||||
kotlin = "2.0.0"
|
||||
llm-client = "0.8.25"
|
||||
llm-client = "0.8.26"
|
||||
okio = "3.9.0"
|
||||
tree-sitter = "0.22.6a"
|
||||
|
||||
|
|
|
|||
|
|
@ -16,9 +16,10 @@ import ee.carlrobert.codegpt.settings.service.google.GoogleSettings;
|
|||
import ee.carlrobert.llm.client.DeserializationUtil;
|
||||
import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionRequest;
|
||||
import ee.carlrobert.llm.client.codegpt.request.chat.ChatCompletionRequest;
|
||||
import ee.carlrobert.llm.client.codegpt.response.CodeGPTException;
|
||||
import ee.carlrobert.llm.client.google.completion.GoogleCompletionRequest;
|
||||
import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest;
|
||||
import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionRequest;
|
||||
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
|
||||
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionEventSourceListener;
|
||||
import ee.carlrobert.llm.client.openai.completion.OpenAITextCompletionEventSourceListener;
|
||||
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest;
|
||||
|
|
@ -108,6 +109,8 @@ public final class CompletionRequestService {
|
|||
}
|
||||
case AZURE -> CompletionClientProvider.getAzureClient()
|
||||
.getChatCompletionAsync(completionRequest, eventListener);
|
||||
case OLLAMA -> CompletionClientProvider.getOllamaClient()
|
||||
.getChatCompletionAsync(completionRequest, eventListener);
|
||||
default -> throw new RuntimeException("Unknown service selected");
|
||||
};
|
||||
}
|
||||
|
|
@ -134,11 +137,6 @@ public final class CompletionRequestService {
|
|||
.getModel(),
|
||||
eventListener);
|
||||
}
|
||||
if (request instanceof OllamaChatCompletionRequest completionRequest) {
|
||||
return CompletionClientProvider.getOllamaClient().getChatCompletionAsync(
|
||||
completionRequest,
|
||||
eventListener);
|
||||
}
|
||||
if (request instanceof LlamaCompletionRequest completionRequest) {
|
||||
return CompletionClientProvider.getLlamaClient().getChatCompletionAsync(
|
||||
completionRequest,
|
||||
|
|
@ -156,8 +154,14 @@ public final class CompletionRequestService {
|
|||
@Override
|
||||
public void run(@NotNull ProgressIndicator indicator) {
|
||||
indicator.setIndeterminate(true);
|
||||
var response = CompletionRequestService.getInstance().getChatCompletion(request);
|
||||
SwingUtilities.invokeLater(() -> eventListener.onComplete(new StringBuilder(response)));
|
||||
try {
|
||||
var response = CompletionRequestService.getInstance().getChatCompletion(request);
|
||||
SwingUtilities.invokeLater(
|
||||
() -> eventListener.onComplete(new StringBuilder(response)));
|
||||
} catch (CodeGPTException e) {
|
||||
SwingUtilities.invokeLater(
|
||||
() -> eventListener.onError(new ErrorDetails(e.getDetail()), e));
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
|
|
@ -181,6 +185,8 @@ public final class CompletionRequestService {
|
|||
.getChatCompletion(completionRequest);
|
||||
case AZURE -> CompletionClientProvider.getAzureClient()
|
||||
.getChatCompletion(completionRequest);
|
||||
case OLLAMA -> CompletionClientProvider.getOllamaClient()
|
||||
.getChatCompletion(completionRequest);
|
||||
default -> throw new RuntimeException("Unknown service selected");
|
||||
};
|
||||
return tryExtractContent(response).orElseThrow();
|
||||
|
|
@ -217,12 +223,6 @@ public final class CompletionRequestService {
|
|||
.getContent().getParts().get(0)
|
||||
.getText();
|
||||
}
|
||||
if (request instanceof OllamaChatCompletionRequest completionRequest) {
|
||||
return CompletionClientProvider.getOllamaClient()
|
||||
.getChatCompletion(completionRequest)
|
||||
.getMessage()
|
||||
.getContent();
|
||||
}
|
||||
if (request instanceof LlamaCompletionRequest completionRequest) {
|
||||
return CompletionClientProvider.getLlamaClient()
|
||||
.getChatCompletion(completionRequest)
|
||||
|
|
|
|||
|
|
@ -3,36 +3,26 @@ package ee.carlrobert.codegpt.completions.factory
|
|||
import com.intellij.openapi.components.service
|
||||
import ee.carlrobert.codegpt.completions.BaseRequestFactory
|
||||
import ee.carlrobert.codegpt.completions.ChatCompletionParameters
|
||||
import ee.carlrobert.codegpt.completions.CompletionRequestUtil.FIX_COMPILE_ERRORS_SYSTEM_PROMPT
|
||||
import ee.carlrobert.codegpt.completions.ConversationType
|
||||
import ee.carlrobert.codegpt.completions.factory.OpenAIRequestFactory.Companion.buildOpenAIMessages
|
||||
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
|
||||
import ee.carlrobert.codegpt.settings.persona.PersonaSettings
|
||||
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings
|
||||
import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionMessage
|
||||
import ee.carlrobert.llm.client.ollama.completion.request.OllamaChatCompletionRequest
|
||||
import ee.carlrobert.llm.client.ollama.completion.request.OllamaParameters
|
||||
import java.io.IOException
|
||||
import java.nio.file.Files
|
||||
import java.nio.file.Path
|
||||
import java.util.*
|
||||
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest
|
||||
import ee.carlrobert.llm.completion.CompletionRequest
|
||||
|
||||
class OllamaRequestFactory : BaseRequestFactory() {
|
||||
|
||||
override fun createChatRequest(params: ChatCompletionParameters): OllamaChatCompletionRequest {
|
||||
override fun createChatRequest(params: ChatCompletionParameters): OpenAIChatCompletionRequest {
|
||||
val model = service<OllamaSettings>().state.model
|
||||
val configuration = service<ConfigurationSettings>().state
|
||||
val settings = service<OllamaSettings>().state
|
||||
return OllamaChatCompletionRequest.Builder(
|
||||
settings.model,
|
||||
buildOllamaMessages(params)
|
||||
)
|
||||
.setStream(true)
|
||||
.setOptions(
|
||||
OllamaParameters.Builder()
|
||||
.numPredict(configuration.maxTokens)
|
||||
.temperature(configuration.temperature.toDouble())
|
||||
.build()
|
||||
val requestBuilder: OpenAIChatCompletionRequest.Builder =
|
||||
OpenAIChatCompletionRequest.Builder(
|
||||
buildOpenAIMessages(model, params, params.referencedFiles)
|
||||
)
|
||||
.build()
|
||||
.setModel(model)
|
||||
.setMaxTokens(configuration.maxTokens)
|
||||
.setStream(true)
|
||||
.setTemperature(configuration.temperature.toDouble())
|
||||
return requestBuilder.build()
|
||||
}
|
||||
|
||||
override fun createBasicCompletionRequest(
|
||||
|
|
@ -40,71 +30,13 @@ class OllamaRequestFactory : BaseRequestFactory() {
|
|||
userPrompt: String,
|
||||
maxTokens: Int,
|
||||
stream: Boolean
|
||||
): OllamaChatCompletionRequest {
|
||||
return OllamaChatCompletionRequest.Builder(
|
||||
service<OllamaSettings>().state.model,
|
||||
listOf(
|
||||
OllamaChatCompletionMessage("system", systemPrompt, null),
|
||||
OllamaChatCompletionMessage("user", userPrompt, null)
|
||||
)
|
||||
): CompletionRequest {
|
||||
val model = service<OllamaSettings>().state.model
|
||||
return OpenAIRequestFactory.createBasicCompletionRequest(
|
||||
systemPrompt,
|
||||
userPrompt,
|
||||
model = model,
|
||||
isStream = stream
|
||||
)
|
||||
.setStream(stream)
|
||||
.build()
|
||||
}
|
||||
|
||||
private fun buildOllamaMessages(params: ChatCompletionParameters): List<OllamaChatCompletionMessage> {
|
||||
val message = params.message
|
||||
val messages = mutableListOf<OllamaChatCompletionMessage>()
|
||||
|
||||
when (params.conversationType) {
|
||||
ConversationType.DEFAULT -> messages.add(
|
||||
OllamaChatCompletionMessage("system", PersonaSettings.getSystemPrompt(), null)
|
||||
)
|
||||
|
||||
ConversationType.FIX_COMPILE_ERRORS -> messages.add(
|
||||
OllamaChatCompletionMessage("system", FIX_COMPILE_ERRORS_SYSTEM_PROMPT, null)
|
||||
)
|
||||
|
||||
else -> {}
|
||||
}
|
||||
|
||||
for (prevMessage in params.conversation.messages) {
|
||||
if (params.retry && prevMessage.id == message.id) break
|
||||
|
||||
prevMessage.imageFilePath?.takeIf { it.isNotEmpty() }?.let { imagePath ->
|
||||
try {
|
||||
val imageBytes = Files.readAllBytes(Path.of(imagePath))
|
||||
val imageBase64 = Base64.getEncoder().encodeToString(imageBytes)
|
||||
messages.add(
|
||||
OllamaChatCompletionMessage(
|
||||
"user",
|
||||
prevMessage.prompt,
|
||||
listOf(imageBase64)
|
||||
)
|
||||
)
|
||||
} catch (e: IOException) {
|
||||
throw RuntimeException(e)
|
||||
}
|
||||
} ?: run {
|
||||
messages.add(
|
||||
OllamaChatCompletionMessage(
|
||||
"user",
|
||||
getPromptWithFilesContext(params),
|
||||
null
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
messages.add(OllamaChatCompletionMessage("assistant", prevMessage.response, null))
|
||||
}
|
||||
|
||||
if (params.imageMediaType != null && params.imageData != null) {
|
||||
val imageBase64 = Base64.getEncoder().encodeToString(params.imageData)
|
||||
messages.add(OllamaChatCompletionMessage("user", message.prompt, listOf(imageBase64)))
|
||||
} else {
|
||||
messages.add(OllamaChatCompletionMessage("user", message.prompt, null))
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -148,41 +148,29 @@ class DefaultToolwindowChatCompletionRequestHandlerTest : IntegrationTest() {
|
|||
val message = Message("TEST_PROMPT")
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
expectOllama(NdJsonStreamHttpExchange { request: RequestEntity ->
|
||||
assertThat(request.uri.path).isEqualTo("/api/chat")
|
||||
assertThat(request.uri.path).isEqualTo("/v1/chat/completions")
|
||||
assertThat(request.method).isEqualTo("POST")
|
||||
assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY")
|
||||
assertThat(request.body)
|
||||
.extracting(
|
||||
"model",
|
||||
"messages",
|
||||
"options.num_predict",
|
||||
"stream"
|
||||
"messages"
|
||||
)
|
||||
.containsExactly(
|
||||
HuggingFaceModel.LLAMA_3_8B_Q6_K.code,
|
||||
listOf(
|
||||
mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
|
||||
mapOf("role" to "user", "content" to "TEST_PROMPT")
|
||||
),
|
||||
99,
|
||||
true
|
||||
)
|
||||
)
|
||||
listOf(
|
||||
jsonMapResponse(
|
||||
e("message", jsonMap(e("content", "Hel"), e("role", "assistant"))),
|
||||
e("done", false)
|
||||
"choices",
|
||||
jsonArray(jsonMap("delta", jsonMap("role", "assistant")))
|
||||
),
|
||||
jsonMapResponse(
|
||||
e("message", jsonMap(e("content", "lo"), e("role", "assistant"))),
|
||||
e("done", false)
|
||||
),
|
||||
jsonMapResponse(
|
||||
e("message", jsonMap(e("content", "!"), e("role", "assistant"))),
|
||||
e("done", false)
|
||||
),
|
||||
jsonMapResponse(
|
||||
e("message", jsonMap(e("content", ""), e("role", "assistant"))),
|
||||
e("done", true)
|
||||
)
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))
|
||||
)
|
||||
})
|
||||
val requestHandler =
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue