feat: Visualize downloaded models (#543)

* feat: Visualize downloaded models

* Simplify GeneralSettings access
This commit is contained in:
Rene Leonhardt 2024-05-13 09:48:55 +02:00 committed by GitHub
parent fcd0808111
commit 9bd7e6e83a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 135 additions and 89 deletions

View file

@@ -62,14 +62,12 @@ public class GenerateGitCommitMessageAction extends AnAction {
@Override
public void update(@NotNull AnActionEvent event) {
var commitWorkflowUi = event.getData(VcsDataKeys.COMMIT_WORKFLOW_UI);
var selectedService = GeneralSettings.getCurrentState().getSelectedService();
if (selectedService == YOU || commitWorkflowUi == null) {
if (GeneralSettings.isSelected(YOU) || commitWorkflowUi == null) {
event.getPresentation().setVisible(false);
return;
}
var callAllowed = CompletionRequestService.isRequestAllowed(
GeneralSettings.getCurrentState().getSelectedService());
var callAllowed = CompletionRequestService.isRequestAllowed();
event.getPresentation().setEnabled(callAllowed
&& new CommitWorkflowChanges(commitWorkflowUi).isFilesSelected());
event.getPresentation().setText(CodeGPTBundle.get(callAllowed

View file

@@ -449,8 +449,7 @@ public class CompletionRequestProvider {
CallParameters callParameters) {
var messages = buildOpenAIMessages(callParameters);
if (model == null
|| GeneralSettings.getCurrentState().getSelectedService() == ServiceType.YOU) {
if (model == null || GeneralSettings.isSelected(ServiceType.YOU)) {
return messages;
}

View file

@@ -81,7 +81,7 @@ public final class CompletionRequestService {
CompletionEventListener<String> eventListener) {
var application = ApplicationManager.getApplication();
var requestProvider = new CompletionRequestProvider(callParameters.getConversation());
return switch (GeneralSettings.getCurrentState().getSelectedService()) {
return switch (GeneralSettings.getSelectedService()) {
case CODEGPT -> CompletionClientProvider.getCodeGPTClient().getChatCompletionAsync(
requestProvider.buildOpenAIChatCompletionRequest(
application.getService(CodeGPTServiceSettings.class)
@@ -141,7 +141,7 @@ public final class CompletionRequestService {
new OpenAIChatCompletionStandardMessage("system", systemPrompt),
new OpenAIChatCompletionStandardMessage("user", gitDiff)))
.setModel(OpenAISettings.getCurrentState().getModel());
var selectedService = GeneralSettings.getCurrentState().getSelectedService();
var selectedService = GeneralSettings.getSelectedService();
switch (selectedService) {
case CODEGPT:
CompletionClientProvider.getCodeGPTClient().getChatCompletionAsync(
@@ -243,7 +243,7 @@ public final class CompletionRequestService {
public Optional<String> getLookupCompletion(String prompt) {
var openaiRequest = CompletionRequestProvider.buildOpenAILookupCompletionRequest(prompt);
var selectedService = GeneralSettings.getCurrentState().getSelectedService();
var selectedService = GeneralSettings.getSelectedService();
switch (selectedService) {
case CODEGPT:
var model = ApplicationManager.getApplication().getService(CodeGPTServiceSettings.class)
@@ -273,8 +273,12 @@ public final class CompletionRequestService {
}
}
public boolean isRequestAllowed() {
return isRequestAllowed(GeneralSettings.getCurrentState().getSelectedService());
public boolean isAllowed() {
return isRequestAllowed();
}
public static boolean isRequestAllowed() {
return isRequestAllowed(GeneralSettings.getSelectedService());
}
public static boolean isRequestAllowed(ServiceType serviceType) {

View file

@@ -1,9 +1,12 @@
package ee.carlrobert.codegpt.completions;
import static ee.carlrobert.codegpt.completions.llama.LlamaModel.getDownloadedMarker;
import static ee.carlrobert.codegpt.completions.llama.LlamaModel.getLlamaModelsPath;
import static java.lang.String.format;
import java.net.MalformedURLException;
import java.net.URL;
import org.jetbrains.annotations.NotNull;
public enum HuggingFaceModel {
@@ -179,6 +182,14 @@ public enum HuggingFaceModel {
}
public String getQuantizationLabel() {
return format("%d-bit precision", quantization);
return format("%s %d-bit precision", downloaded(), quantization);
}
public boolean isDownloaded() {
return getLlamaModelsPath().resolve(fileName).toFile().exists();
}
private @NotNull String downloaded() {
return getDownloadedMarker(isDownloaded());
}
}

View file

@@ -5,8 +5,15 @@ import static java.util.stream.Collectors.toSet;
import ee.carlrobert.codegpt.codecompletions.InfillPromptTemplate;
import ee.carlrobert.codegpt.completions.HuggingFaceModel;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.function.BiConsumer;
import org.jetbrains.annotations.NotNull;
public enum LlamaModel {
@@ -174,18 +181,37 @@ public enum LlamaModel {
}
public static @NotNull LlamaModel findByHuggingFaceModel(HuggingFaceModel huggingFaceModel) {
for (var llamaModel : LlamaModel.values()) {
if (llamaModel.getHuggingFaceModels().contains(huggingFaceModel)) {
return llamaModel;
}
}
return Arrays.stream(LlamaModel.values())
.filter(model -> model.getHuggingFaceModels().contains(huggingFaceModel))
.findFirst()
.orElseThrow(() -> new RuntimeException("Unable to find correct LLM"));
}
throw new RuntimeException("Unable to find correct LLM");
public @NotNull List<HuggingFaceModel> filterSelectedModelsBySize(ModelSize selectedModelSize) {
return selectedModelSize != null ? getHuggingFaceModels().stream()
.filter(model -> selectedModelSize.size() == model.getParameterSize())
.toList() : List.of();
}
public boolean anyDownloaded() {
return huggingFaceModels.stream().anyMatch(HuggingFaceModel::isDownloaded);
}
public String getDownloadedMarker() {
return getDownloadedMarker(anyDownloaded());
}
public static String getDownloadedMarker(boolean downloaded) {
return downloaded ? "" : "\u2001";
}
public static @NotNull Path getLlamaModelsPath() {
return Paths.get(System.getProperty("user.home"), ".codegpt/models/gguf");
}
@Override
public String toString() {
return String.join(" ", label, getFormattedModelSizeRange());
return String.join(" ", getDownloadedMarker(), label, getFormattedModelSizeRange());
}
public String getLabel() {
@@ -218,12 +244,37 @@ public enum LlamaModel {
return format("(%dB - %dB)", Collections.min(parameters), Collections.max(parameters));
}
public List<Integer> getSortedUniqueModelSizes() {
public List<ModelSize> getSortedUniqueModelSizes() {
return huggingFaceModels.stream()
.map(HuggingFaceModel::getParameterSize)
.collect(toSet())
.stream()
.sorted()
.toList();
.map(hfm -> new ModelSize(hfm.getParameterSize(), hfm.isDownloaded()))
.sorted()
.collect(LinkedHashSet::new, ModelSize.skipSameSize(), Set::addAll)
.stream().toList();
}
public record ModelSize(int size, boolean downloaded) implements Comparable<ModelSize> {
// Sort by size, but downloaded comes first: [ 7B, 13B, 13B, 34B]
private static final Comparator<ModelSize> sizeDownloadedFirst = Comparator
.comparing(ModelSize::size)
.thenComparing(Comparator.comparing(ModelSize::downloaded).reversed());
@Override
public int compareTo(@NotNull ModelSize other) {
return sizeDownloadedFirst.compare(this, other);
}
private static @NotNull BiConsumer<Set<ModelSize>, ModelSize> skipSameSize() {
return (s, e) -> {
if (s.stream().noneMatch(v -> v.size == e.size)) {
s.add(e);
}
};
}
@Override
public String toString() {
return "%s %dB".formatted(getDownloadedMarker(downloaded), size);
}
}
}

View file

@@ -49,8 +49,7 @@ public final class ConversationService {
conversation.setClientCode(clientCode);
conversation.setCreatedOn(LocalDateTime.now());
conversation.setUpdatedOn(LocalDateTime.now());
conversation.setModel(getModelForSelectedService(
GeneralSettings.getCurrentState().getSelectedService()));
conversation.setModel(getModelForSelectedService(GeneralSettings.getSelectedService()));
return conversation;
}
@@ -113,7 +112,7 @@
}
public Conversation startConversation() {
var completionCode = GeneralSettings.getCurrentState().getSelectedService().getCompletionCode();
var completionCode = GeneralSettings.getSelectedService().getCompletionCode();
var conversation = createConversation(completionCode);
conversationState.setCurrentConversation(conversation);
addConversation(conversation);

View file

@@ -41,6 +41,14 @@ public class GeneralSettings implements PersistentStateComponent<GeneralSettings
return ApplicationManager.getApplication().getService(GeneralSettings.class);
}
public static ServiceType getSelectedService() {
return getCurrentState().getSelectedService();
}
public static boolean isSelected(ServiceType serviceType) {
return getSelectedService() == serviceType;
}
public void sync(Conversation conversation) {
var clientCode = conversation.getClientCode();
if ("chat.completion".equals(clientCode)) {
@@ -105,7 +113,8 @@ public class GeneralSettings implements PersistentStateComponent<GeneralSettings
var huggingFaceModel = llamaSettings.getHuggingFaceModel();
var llamaModel = LlamaModel.findByHuggingFaceModel(huggingFaceModel);
return String.format(
"%s %dB (Q%d)",
"%s %s %dB (Q%d)",
llamaModel.getDownloadedMarker(),
llamaModel.getLabel(),
huggingFaceModel.getParameterSize(),
huggingFaceModel.getQuantization());

View file

@@ -27,6 +27,7 @@ import ee.carlrobert.codegpt.CodeGPTPlugin;
import ee.carlrobert.codegpt.codecompletions.InfillPromptTemplate;
import ee.carlrobert.codegpt.completions.HuggingFaceModel;
import ee.carlrobert.codegpt.completions.llama.LlamaModel;
import ee.carlrobert.codegpt.completions.llama.LlamaModel.ModelSize;
import ee.carlrobert.codegpt.completions.llama.LlamaServerAgent;
import ee.carlrobert.codegpt.completions.llama.PromptTemplate;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
@@ -113,9 +114,7 @@ public class LlamaModelPreferencesForm {
var llamaServerAgent = ApplicationManager.getApplication().getService(LlamaServerAgent.class);
huggingFaceModelComboBox.setEnabled(!llamaServerAgent.isServerRunning());
var modelSizeComboBoxModel = new DefaultComboBoxModel<ModelSize>();
var initialModelSizes = llamaModel.getSortedUniqueModelSizes().stream()
.map(ModelSize::new)
.toList();
var initialModelSizes = llamaModel.getSortedUniqueModelSizes();
modelSizeComboBoxModel.addAll(initialModelSizes);
modelSizeComboBoxModel.setSelectedItem(initialModelSizes.get(0));
var modelComboBoxModel = new EnumComboBoxModel<>(LlamaModel.class);
@@ -315,22 +314,15 @@ public class LlamaModelPreferencesForm {
comboBox.setSelectedItem(llamaModel);
comboBox.addItemListener(e -> {
var selectedModel = (LlamaModel) e.getItem();
var modelSizes = selectedModel.getSortedUniqueModelSizes().stream()
.map(ModelSize::new)
.toList();
var modelSizes = selectedModel.getSortedUniqueModelSizes();
modelSizeComboBoxModel.removeAllElements();
modelSizeComboBoxModel.addAll(modelSizes);
modelSizeComboBoxModel.setSelectedItem(modelSizes.get(0));
modelSizeComboBox.setEnabled(modelSizes.size() > 1);
var huggingFaceModels = selectedModel.getHuggingFaceModels().stream()
.filter(model -> {
var selectedModelSize = (ModelSize) modelSizeComboBoxModel.getSelectedItem();
return selectedModelSize != null
&& selectedModelSize.size() == model.getParameterSize();
})
.toList();
var huggingFaceModels = selectedModel.filterSelectedModelsBySize(
(ModelSize) modelSizeComboBoxModel.getSelectedItem());
huggingFaceComboBoxModel.removeAllElements();
huggingFaceComboBoxModel.addAll(huggingFaceModels);
@@ -348,13 +340,8 @@ public class LlamaModelPreferencesForm {
comboBox.setSelectedItem(modelSizeComboBoxModel.getSelectedItem());
comboBox.addItemListener(e -> {
var selectedModel = llamaModelComboBoxModel.getSelectedItem();
var models = selectedModel.getHuggingFaceModels().stream()
.filter(model -> {
var selectedModelSize = (ModelSize) modelSizeComboBoxModel.getSelectedItem();
return selectedModelSize != null
&& selectedModelSize.size() == model.getParameterSize();
})
.toList();
var models = selectedModel.filterSelectedModelsBySize(
(ModelSize) modelSizeComboBoxModel.getSelectedItem());
if (!models.isEmpty()) {
huggingFaceComboBoxModel.removeAllElements();
huggingFaceComboBoxModel.addAll(models);
@@ -494,11 +481,4 @@ public class LlamaModelPreferencesForm {
private record ModelDetails(double fileSize, double maxRAMRequired) {
}
private record ModelSize(int size) {
@Override
public String toString() {
return size + "B";
}
}
}

View file

@@ -229,7 +229,7 @@ public class ChatToolWindowTabPanel implements Disposable {
private void call(CallParameters callParameters, ResponsePanel responsePanel) {
var responseContainer = (ChatMessageResponseBody) responsePanel.getContent();
if (!CompletionRequestService.getInstance().isRequestAllowed()) {
if (!CompletionRequestService.getInstance().isAllowed()) {
responseContainer.displayMissingCredential();
return;
}
@@ -359,7 +359,7 @@ public class ChatToolWindowTabPanel implements Disposable {
gbc.fill = GridBagConstraints.HORIZONTAL;
gbc.gridy = 1;
rootPanel.add(
createUserPromptPanel(GeneralSettings.getCurrentState().getSelectedService()), gbc);
createUserPromptPanel(GeneralSettings.getSelectedService()), gbc);
return rootPanel;
}
}

View file

@@ -28,7 +28,7 @@ public class ChatToolWindowScrollablePanel extends ScrollablePanel {
public void displayLandingView(JComponent landingView) {
clearAll();
add(landingView);
if (GeneralSettings.getCurrentState().getSelectedService() == ServiceType.CODEGPT
if (GeneralSettings.isSelected(ServiceType.CODEGPT)
&& !CredentialsStore.INSTANCE.isCredentialSet(CredentialKey.CODEGPT_API_KEY)) {
var panel = new ResponsePanel()

View file

@@ -159,7 +159,7 @@ public class ModelComboBoxAction extends ComboBoxAction {
if (!YouUserManager.getInstance().isSubscribed()
&& youSettings.getChatMode() != YouCompletionMode.DEFAULT) {
youSettings.setChatMode(YouCompletionMode.DEFAULT);
updateTemplatePresentation(GeneralSettings.getCurrentState().getSelectedService());
updateTemplatePresentation(GeneralSettings.getSelectedService());
}
}
);
@@ -240,9 +240,11 @@ public class ModelComboBoxAction extends ComboBoxAction {
private String getSelectedHuggingFace() {
var huggingFaceModel = LlamaSettings.getCurrentState().getHuggingFaceModel();
var llamaModel = LlamaModel.findByHuggingFaceModel(huggingFaceModel);
return format(
"%s %dB (Q%d)",
LlamaModel.findByHuggingFaceModel(huggingFaceModel).getLabel(),
"%s %s %dB (Q%d)",
llamaModel.getDownloadedMarker(),
llamaModel.getLabel(),
huggingFaceModel.getParameterSize(),
huggingFaceModel.getQuantization());
}

View file

@@ -163,7 +163,7 @@ public class TotalTokensPanel extends JPanel {
}
private String getIconToolTipText(String html) {
if (GeneralSettings.getCurrentState().getSelectedService() != ServiceType.OPENAI) {
if (!GeneralSettings.isSelected(ServiceType.OPENAI)) {
return """
<html
<p style="margin: 4px 0;">

View file

@@ -191,7 +191,7 @@ public class UserPromptTextArea extends JPanel {
handleSubmit();
}
}));
var selectedService = GeneralSettings.getCurrentState().getSelectedService();
var selectedService = GeneralSettings.getSelectedService();
if (selectedService == ANTHROPIC
|| selectedService == OLLAMA
|| (selectedService == OPENAI

View file

@@ -17,35 +17,28 @@ abstract class CodeCompletionFeatureToggleActions(
private val enableFeatureAction: Boolean
) : DumbAwareAction() {
override fun actionPerformed(e: AnActionEvent) {
when (GeneralSettings.getCurrentState().selectedService) {
CODEGPT ->
service<CodeGPTServiceSettings>().state.codeCompletionSettings.codeCompletionsEnabled
override fun actionPerformed(e: AnActionEvent) = when (GeneralSettings.getSelectedService()) {
CODEGPT -> service<CodeGPTServiceSettings>().state.codeCompletionSettings::codeCompletionsEnabled::set
OPENAI ->
OpenAISettings.getCurrentState().isCodeCompletionsEnabled = enableFeatureAction
OPENAI -> OpenAISettings.getCurrentState()::setCodeCompletionsEnabled
LLAMA_CPP ->
LlamaSettings.getCurrentState().isCodeCompletionsEnabled = enableFeatureAction
LLAMA_CPP -> LlamaSettings.getCurrentState()::setCodeCompletionsEnabled
OLLAMA -> service<OllamaSettings>().state.codeCompletionsEnabled = enableFeatureAction
CUSTOM_OPENAI -> service<CustomServiceSettings>().state
.codeCompletionSettings
.codeCompletionsEnabled = enableFeatureAction
OLLAMA -> service<OllamaSettings>().state::codeCompletionsEnabled::set
ANTHROPIC,
AZURE,
YOU,
GOOGLE,
null -> { /* no-op for these services */
}
}
}
CUSTOM_OPENAI -> service<CustomServiceSettings>().state.codeCompletionSettings::codeCompletionsEnabled::set
ANTHROPIC,
AZURE,
YOU,
GOOGLE,
null -> { _: Boolean -> Unit } // no-op for these services
}(enableFeatureAction)
override fun update(e: AnActionEvent) {
val selectedService = GeneralSettings.getCurrentState().selectedService
val selectedService = GeneralSettings.getSelectedService()
val codeCompletionEnabled =
service<CodeCompletionService>().isCodeCompletionsEnabled(selectedService)
e.project?.service<CodeCompletionService>()?.isCodeCompletionsEnabled(selectedService) ?: false
e.presentation.isVisible = codeCompletionEnabled != enableFeatureAction
e.presentation.isEnabled = when (selectedService) {
CODEGPT,

View file

@@ -37,7 +37,7 @@ class CodeCompletionService {
requestDetails: InfillRequestDetails,
eventListener: CompletionEventListener<String>
): EventSource =
when (val selectedService = GeneralSettings.getCurrentState().selectedService) {
when (val selectedService = GeneralSettings.getSelectedService()) {
CODEGPT -> CompletionClientProvider.getCodeGPTClient()
.getCompletionAsync(buildCodeGPTRequest(requestDetails), eventListener)
@@ -56,4 +56,4 @@ class CodeCompletionService {
else -> throw IllegalArgumentException("Code completion not supported for ${selectedService.name}")
}
}
}

View file

@@ -67,7 +67,7 @@ class CodeGPTInlineCompletionProvider : InlineCompletionProvider {
}
override fun isEnabled(event: InlineCompletionEvent): Boolean {
val selectedService = GeneralSettings.getCurrentState().selectedService
val selectedService = GeneralSettings.getSelectedService()
val codeCompletionsEnabled = when (selectedService) {
ServiceType.CODEGPT -> service<CodeGPTServiceSettings>().state.codeCompletionSettings.codeCompletionsEnabled
ServiceType.OPENAI -> OpenAISettings.getCurrentState().isCodeCompletionsEnabled