Use enum value to store selected service (#265)

This commit is contained in:
Carl-Robert 2023-11-08 19:17:25 +02:00 committed by GitHub
parent ff60d1eab5
commit cfa5ff7776
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 166 additions and 233 deletions

View file

@ -5,6 +5,7 @@ import static java.util.stream.Collectors.toList;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.components.Service;
import ee.carlrobert.codegpt.conversations.message.Message;
import ee.carlrobert.codegpt.settings.service.ServiceType;
import ee.carlrobert.codegpt.settings.state.AzureSettingsState;
import ee.carlrobert.codegpt.settings.state.LlamaSettingsState;
import ee.carlrobert.codegpt.settings.state.OpenAISettingsState;
@ -43,17 +44,9 @@ public final class ConversationService {
var conversation = new Conversation();
conversation.setId(UUID.randomUUID());
conversation.setClientCode(clientCode);
if (settings.isUseYouService()) {
conversation.setModel("YouCode");
} else if (settings.isUseAzureService()) {
conversation.setModel(AzureSettingsState.getInstance().getModel());
} else if (settings.isUseOpenAIService()) {
conversation.setModel(OpenAISettingsState.getInstance().getModel());
} else {
conversation.setModel(LlamaSettingsState.getInstance().getHuggingFaceModel().getCode());
}
conversation.setCreatedOn(LocalDateTime.now());
conversation.setUpdatedOn(LocalDateTime.now());
conversation.setModel(getModelForSelectedService(settings.getSelectedService()));
return conversation;
}
@ -121,22 +114,9 @@ public final class ConversationService {
conversationState.setCurrentConversation(conversation);
}
private String getClientCode() {
var settings = SettingsState.getInstance();
if (settings.isUseOpenAIService()) {
return "chat.completion";
}
if (settings.isUseAzureService()) {
return "azure.chat.completion";
}
if (settings.isUseLlamaService()) {
return "llama.chat.completion";
}
return "you.chat.completion";
}
public Conversation startConversation() {
var conversation = createConversation(getClientCode());
var completionCode = SettingsState.getInstance().getSelectedService().getCompletionCode();
var conversation = createConversation(completionCode);
conversationState.setCurrentConversation(conversation);
addConversation(conversation);
return conversation;
@ -147,32 +127,6 @@ public final class ConversationService {
conversationState.setCurrentConversation(null);
}
public Optional<Conversation> getPreviousConversation() {
return tryGetNextOrPreviousConversation(true);
}
public Optional<Conversation> getNextConversation() {
return tryGetNextOrPreviousConversation(false);
}
private Optional<Conversation> tryGetNextOrPreviousConversation(boolean isPrevious) {
var currentConversation = ConversationsState.getCurrentConversation();
if (currentConversation != null) {
var sortedConversations = getSortedConversations();
for (int i = 0; i < sortedConversations.size(); i++) {
var conversation = sortedConversations.get(i);
if (conversation != null && conversation.getId().equals(currentConversation.getId())) {
// higher index indicates older conversation
var previousIndex = isPrevious ? i + 1 : i - 1;
if (isPrevious ? previousIndex < sortedConversations.size() : previousIndex != -1) {
return Optional.of(sortedConversations.get(previousIndex));
}
}
}
}
return Optional.empty();
}
public void deleteConversation(Conversation conversation) {
var iterator = conversationState.getConversationsMapping()
.get(conversation.getClientCode())
@ -205,4 +159,48 @@ public final class ConversationService {
conversation.discardTokenLimits();
saveConversation(conversation);
}
public Optional<Conversation> getPreviousConversation() {
return tryGetNextOrPreviousConversation(true);
}
public Optional<Conversation> getNextConversation() {
return tryGetNextOrPreviousConversation(false);
}
private Optional<Conversation> tryGetNextOrPreviousConversation(boolean isPrevious) {
var currentConversation = ConversationsState.getCurrentConversation();
if (currentConversation != null) {
var sortedConversations = getSortedConversations();
for (int i = 0; i < sortedConversations.size(); i++) {
var conversation = sortedConversations.get(i);
if (conversation != null && conversation.getId().equals(currentConversation.getId())) {
// higher index indicates older conversation
var previousIndex = isPrevious ? i + 1 : i - 1;
if (isPrevious ? previousIndex < sortedConversations.size() : previousIndex != -1) {
return Optional.of(sortedConversations.get(previousIndex));
}
}
}
}
return Optional.empty();
}
private static String getModelForSelectedService(ServiceType serviceType) {
switch (serviceType) {
case OPENAI:
return OpenAISettingsState.getInstance().getModel();
case AZURE:
return AzureSettingsState.getInstance().getModel();
case YOU:
return "YouCode";
case LLAMA_CPP:
var llamaSettings = LlamaSettingsState.getInstance();
return llamaSettings.isUseCustomModel() ?
llamaSettings.getCustomLlamaModelPath() :
llamaSettings.getHuggingFaceModel().getCode();
default:
throw new RuntimeException("Could not find corresponding service mapping");
}
}
}