diff --git a/src/main/cpp/llama.cpp b/src/main/cpp/llama.cpp
index 46e12c46..e0f55618 160000
--- a/src/main/cpp/llama.cpp
+++ b/src/main/cpp/llama.cpp
@@ -1 +1 @@
-Subproject commit 46e12c4692a37bdd31a0432fc5153d7d22bc7f72
+Subproject commit e0f556186b6e1f2b7032a1479edf5e89e2b1bd86
diff --git a/src/main/java/ee/carlrobert/codegpt/completions/llama/LlamaModel.java b/src/main/java/ee/carlrobert/codegpt/completions/llama/LlamaModel.java
index c9d58a97..509d1319 100644
--- a/src/main/java/ee/carlrobert/codegpt/completions/llama/LlamaModel.java
+++ b/src/main/java/ee/carlrobert/codegpt/completions/llama/LlamaModel.java
@@ -214,6 +214,13 @@ public enum LlamaModel {
     return String.join(" ", getDownloadedMarker(), label, getFormattedModelSizeRange());
   }
 
+  /**
+   * Formats the model name shown in server notifications, e.g. {@code CodeLlama 7B 4-bit}.
+   */
+  public @NotNull String toString(@NotNull HuggingFaceModel hfm) {
+    return "%s %dB %d-bit".formatted(label, hfm.getParameterSize(), hfm.getQuantization());
+  }
+
   public String getLabel() {
     return label;
   }
@@ -234,6 +241,16 @@ public enum LlamaModel {
     return huggingFaceModels;
   }
 
+  /**
+   * Downloaded model with the largest parameter size, or the first one if none is downloaded.
+   */
+  public HuggingFaceModel getLastExistingModelOrFirst() {
+    return huggingFaceModels.stream()
+        .filter(HuggingFaceModel::isDownloaded)
+        .max(Comparator.comparing(HuggingFaceModel::getParameterSize))
+        .orElse(huggingFaceModels.get(0));
+  }
+
   public String getFormattedModelSizeRange() {
     var parameters = huggingFaceModels.stream()
        .map(HuggingFaceModel::getParameterSize)
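For orientation, a minimal usage sketch of how the two new LlamaModel helpers combine (illustrative only, outside the patch); the CODE_LLAMA constant and the idea that a 13B variant has been downloaded are assumptions made for the example.

    // Illustration only: assumes LlamaModel.CODE_LLAMA exists and that a 13B quant is downloaded.
    val model = LlamaModel.CODE_LLAMA
    val hfm = model.getLastExistingModelOrFirst() // largest downloaded variant, otherwise the first entry
    val displayName = model.toString(hfm)         // e.g. "CodeLlama 13B 4-bit"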
diff --git a/src/main/java/ee/carlrobert/codegpt/settings/service/llama/form/LlamaModelPreferencesForm.java b/src/main/java/ee/carlrobert/codegpt/settings/service/llama/form/LlamaModelPreferencesForm.java
index 47ed4de2..4b270ea7 100644
--- a/src/main/java/ee/carlrobert/codegpt/settings/service/llama/form/LlamaModelPreferencesForm.java
+++ b/src/main/java/ee/carlrobert/codegpt/settings/service/llama/form/LlamaModelPreferencesForm.java
@@ -113,21 +113,15 @@ public class LlamaModelPreferencesForm {
     var llamaServerAgent = ApplicationManager.getApplication().getService(LlamaServerAgent.class);
     huggingFaceModelComboBox.setEnabled(!llamaServerAgent.isServerRunning());
     var modelSizeComboBoxModel = new DefaultComboBoxModel<ModelSize>();
-    var initialModelSizes = llamaModel.getSortedUniqueModelSizes();
-    modelSizeComboBoxModel.addAll(initialModelSizes);
-    var selectedModelSize = initialModelSizes.stream()
-        .filter(ms -> ms.size() == llm.getParameterSize())
-        .findFirst().orElse(initialModelSizes.get(0));
-    modelSizeComboBoxModel.setSelectedItem(selectedModelSize);
     var modelComboBoxModel = new EnumComboBoxModel<>(LlamaModel.class);
-    modelComboBox = createModelComboBox(modelComboBoxModel, llamaModel, modelSizeComboBoxModel);
+    modelComboBox = createModelComboBox(
+        modelComboBoxModel, llamaModel, llm, llamaServerAgent, modelSizeComboBoxModel);
     modelComboBox.setEnabled(!llamaServerAgent.isServerRunning());
     modelSizeComboBox = createModelSizeComboBox(
         modelComboBoxModel,
         modelSizeComboBoxModel,
+        llamaServerAgent,
         huggingFaceComboBoxModel);
-    modelSizeComboBox.setEnabled(
-        initialModelSizes.size() > 1 && !llamaServerAgent.isServerRunning());
     browsableCustomModelTextField = createBrowsableCustomModelTextField(
         !llamaServerAgent.isServerRunning());
     browsableCustomModelTextField.setText(llamaSettings.getCustomLlamaModelPath());
@@ -310,40 +304,57 @@ public class LlamaModelPreferencesForm {
   private ComboBox<LlamaModel> createModelComboBox(
       EnumComboBoxModel<LlamaModel> llamaModelEnumComboBoxModel,
       LlamaModel llamaModel,
+      HuggingFaceModel llm,
+      LlamaServerAgent llamaServerAgent,
       DefaultComboBoxModel<ModelSize> modelSizeComboBoxModel) {
     var comboBox = new ComboBox<>(llamaModelEnumComboBoxModel);
     comboBox.setPreferredSize(new Dimension(280, comboBox.getPreferredSize().height));
     comboBox.setSelectedItem(llamaModel);
+    initializeModelSizes(llamaModel, llm, modelSizeComboBoxModel);
     comboBox.addItemListener(e -> {
       var selectedModel = (LlamaModel) e.getItem();
-      var modelSizes = selectedModel.getSortedUniqueModelSizes();
-
-      modelSizeComboBoxModel.removeAllElements();
-      modelSizeComboBoxModel.addAll(modelSizes);
-      modelSizeComboBoxModel.setSelectedItem(modelSizes.get(0));
-      modelSizeComboBox.setEnabled(modelSizes.size() > 1);
-
-      var huggingFaceModels = selectedModel.filterSelectedModelsBySize(
-          (ModelSize) modelSizeComboBoxModel.getSelectedItem());
-
+      var hfm = selectedModel.getLastExistingModelOrFirst();
+      var modelSize = initializeModelSizes(selectedModel, hfm, modelSizeComboBoxModel);
+      var huggingFaceModels = selectedModel.filterSelectedModelsBySize(modelSize);
       huggingFaceComboBoxModel.removeAllElements();
       huggingFaceComboBoxModel.addAll(huggingFaceModels);
-      huggingFaceComboBoxModel.setSelectedItem(huggingFaceModels.get(0));
+      huggingFaceComboBoxModel.setSelectedItem(hfm);
+      modelSizeComboBox.setEnabled(
+          modelSizeComboBox.getModel().getSize() > 1 && !llamaServerAgent.isServerRunning());
     });
     return comboBox;
   }
 
+  private static ModelSize initializeModelSizes(
+      LlamaModel llamaModel,
+      HuggingFaceModel hfm,
+      DefaultComboBoxModel<ModelSize> modelSizeComboBoxModel) {
+    var modelSizes = llamaModel.getSortedUniqueModelSizes();
+    modelSizeComboBoxModel.removeAllElements();
+    modelSizeComboBoxModel.addAll(modelSizes);
+    var selectedModelSize = modelSizes.stream()
+        .filter(ms -> ms.size() == hfm.getParameterSize())
+        .findFirst().orElse(modelSizes.get(0));
+    modelSizeComboBoxModel.setSelectedItem(selectedModelSize);
+    return selectedModelSize;
+  }
+
   private ComboBox<ModelSize> createModelSizeComboBox(
       EnumComboBoxModel<LlamaModel> llamaModelComboBoxModel,
       DefaultComboBoxModel<ModelSize> modelSizeComboBoxModel,
+      LlamaServerAgent llamaServerAgent,
       DefaultComboBoxModel<HuggingFaceModel> huggingFaceComboBoxModel) {
     var comboBox = new ComboBox<>(modelSizeComboBoxModel);
     comboBox.setPreferredSize(modelComboBox.getPreferredSize());
     comboBox.setSelectedItem(modelSizeComboBoxModel.getSelectedItem());
+    comboBox.setEnabled(
+        modelSizeComboBoxModel.getSize() > 1 && !llamaServerAgent.isServerRunning());
     comboBox.addItemListener(e -> {
       var selectedModel = llamaModelComboBoxModel.getSelectedItem();
       var models = selectedModel.filterSelectedModelsBySize(
          (ModelSize) modelSizeComboBoxModel.getSelectedItem());
+      comboBox.setEnabled(
+          modelSizeComboBoxModel.getSize() > 1 && !llamaServerAgent.isServerRunning());
       if (!models.isEmpty()) {
         huggingFaceComboBoxModel.removeAllElements();
         huggingFaceComboBoxModel.addAll(models);
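To make the combo-box cascade concrete, a small sketch of what the new selection logic does when the user switches model families (illustrative only, outside the patch); the CODE_LLAMA constant and the downloaded 13B variant are assumptions.

    // Illustration only: assumes a 13B quant of the newly selected family is already downloaded.
    val selected = LlamaModel.CODE_LLAMA              // value picked in the model combo box
    val hfm = selected.getLastExistingModelOrFirst()  // the downloaded 13B variant
    // initializeModelSizes(...) then preselects the 13B size entry, and
    // filterSelectedModelsBySize(...) narrows the Hugging Face combo box to the 13B quants,
    // with hfm preselected.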
diff --git a/src/main/java/ee/carlrobert/codegpt/ui/OverlayUtil.java b/src/main/java/ee/carlrobert/codegpt/ui/OverlayUtil.java
index 610e03e0..4996ff8f 100644
--- a/src/main/java/ee/carlrobert/codegpt/ui/OverlayUtil.java
+++ b/src/main/java/ee/carlrobert/codegpt/ui/OverlayUtil.java
@@ -1,5 +1,6 @@
 package ee.carlrobert.codegpt.ui;
 
+import static com.intellij.notification.NotificationType.INFORMATION;
 import static com.intellij.openapi.ui.Messages.CANCEL;
 import static com.intellij.openapi.ui.Messages.OK;
 import static ee.carlrobert.codegpt.Icons.Default;
@@ -10,6 +11,7 @@ import com.intellij.execution.ExecutionBundle;
 import com.intellij.notification.Notification;
 import com.intellij.notification.NotificationType;
 import com.intellij.notification.Notifications;
+import com.intellij.openapi.actionSystem.AnAction;
 import com.intellij.openapi.actionSystem.AnActionEvent;
 import com.intellij.openapi.project.Project;
 import com.intellij.openapi.ui.DoNotAskOption;
@@ -26,6 +28,7 @@ import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings;
 import ee.carlrobert.codegpt.util.EditorUtil;
 import java.awt.Point;
 import java.awt.event.MouseEvent;
+import java.util.Arrays;
 import javax.swing.JComponent;
 import org.jetbrains.annotations.NotNull;
 
@@ -37,30 +40,53 @@ public class OverlayUtil {
   private OverlayUtil() {
   }
 
-  public static Notification getDefaultNotification(String content, NotificationType type) {
-    return new Notification(NOTIFICATION_GROUP_ID, "CodeGPT", content, type);
+  public static Notification getDefaultNotification(
+      @NotNull String content, @NotNull AnAction... actions) {
+    return getDefaultNotification(content, INFORMATION, actions);
   }
 
-  public static Notification getStickyNotification(String content, NotificationType type) {
-    return new Notification(NOTIFICATION_GROUP_STICKY_ID, "CodeGPT", content, type);
-  }
-
-  public static Notification showNotification(String content) {
-    return showNotification(content, NotificationType.INFORMATION);
-  }
-
-  public static Notification showNotification(String content, NotificationType type) {
-    var notification = getDefaultNotification(content, type);
-    Notifications.Bus.notify(notification);
+  public static Notification getDefaultNotification(
+      @NotNull String content, @NotNull NotificationType type, @NotNull AnAction... actions) {
+    var notification = new Notification(NOTIFICATION_GROUP_ID, "CodeGPT", content, type);
+    Arrays.asList(actions).forEach(notification::addAction);
     return notification;
   }
 
-  public static Notification stickyNotification(String content) {
-    return stickyNotification(content, NotificationType.INFORMATION);
+  public static Notification getStickyNotification(
+      @NotNull String content, @NotNull AnAction... actions) {
+    return getStickyNotification(content, INFORMATION, actions);
   }
 
-  public static Notification stickyNotification(String content, NotificationType type) {
-    var notification = getStickyNotification(content, type);
+  public static Notification getStickyNotification(
+      @NotNull String content, @NotNull NotificationType type, @NotNull AnAction... actions) {
+    var notification = new Notification(NOTIFICATION_GROUP_STICKY_ID, "CodeGPT", content, type);
+    Arrays.asList(actions).forEach(notification::addAction);
+    return notification;
+  }
+
+  public static Notification showNotification(
+      @NotNull String content, @NotNull AnAction... actions) {
+    return showNotification(content, INFORMATION, actions);
+  }
+
+  public static Notification showNotification(
+      @NotNull String content, @NotNull NotificationType type, @NotNull AnAction... actions) {
+    return notify(getDefaultNotification(content, type, actions));
+  }
+
+  public static Notification stickyNotification(
+      @NotNull String content, @NotNull AnAction... actions) {
+    return stickyNotification(content, INFORMATION, actions);
+  }
+
+  public static Notification stickyNotification(
+      @NotNull String content, @NotNull NotificationType type, @NotNull AnAction... actions) {
+    return notify(getStickyNotification(content, type, actions));
+  }
+
+  public static @NotNull Notification notify(
+      @NotNull Notification notification, @NotNull AnAction... actions) {
+    Arrays.asList(actions).forEach(notification::addAction);
     Notifications.Bus.notify(notification);
     return notification;
   }
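A short usage sketch of the reworked OverlayUtil API (illustrative only, outside the patch): the notification type now defaults to INFORMATION and any number of actions can be attached. The message text and action label below are placeholders.

    // Illustration only: message and action label are made up; assumes the
    // com.intellij.notification.NotificationAction import.
    OverlayUtil.showNotification(
        "Server stopped",
        NotificationAction.createSimpleExpiring("Start") { /* re-run the start action */ }
    )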
diff --git a/src/main/kotlin/ee/carlrobert/codegpt/actions/LlamaServerToggleActions.kt b/src/main/kotlin/ee/carlrobert/codegpt/actions/LlamaServerToggleActions.kt
index 3dbf1a4b..74f84430 100644
--- a/src/main/kotlin/ee/carlrobert/codegpt/actions/LlamaServerToggleActions.kt
+++ b/src/main/kotlin/ee/carlrobert/codegpt/actions/LlamaServerToggleActions.kt
@@ -1,12 +1,14 @@
 package ee.carlrobert.codegpt.actions
 
 import com.intellij.notification.Notification
+import com.intellij.notification.NotificationAction.createSimpleExpiring
 import com.intellij.openapi.actionSystem.ActionManager
 import com.intellij.openapi.actionSystem.ActionUpdateThread
 import com.intellij.openapi.actionSystem.AnActionEvent
 import com.intellij.openapi.components.service
 import com.intellij.openapi.project.DumbAwareAction
 import ee.carlrobert.codegpt.CodeGPTBundle
+import ee.carlrobert.codegpt.completions.llama.LlamaModel.findByHuggingFaceModel
 import ee.carlrobert.codegpt.completions.llama.LlamaServerAgent
 import ee.carlrobert.codegpt.completions.llama.LlamaServerStartupParams
 import ee.carlrobert.codegpt.settings.GeneralSettings
@@ -19,6 +21,13 @@ import ee.carlrobert.codegpt.ui.OverlayUtil.showNotification
 import ee.carlrobert.codegpt.ui.OverlayUtil.stickyNotification
 import java.util.function.Consumer
 
+private const val STARTING = "settingsConfigurable.service.llama.progress.startingServer"
+private const val RUNNING = "settingsConfigurable.service.llama.progress.serverRunning"
+private const val STOPPING = "settingsConfigurable.service.llama.progress.stoppingServer"
+private const val STOPPED = "settingsConfigurable.service.llama.progress.serverStopped"
+private const val START = "settingsConfigurable.service.llama.stopServer.opposite"
+private const val STOP = "settingsConfigurable.service.llama.startServer.opposite"
+
 /**
  * Start or stop server (if selected model exists) showing notifications
  */
@@ -27,13 +36,15 @@ abstract class LlamaServerToggleActions(
 ) : DumbAwareAction() {
     companion object {
         fun expireOtherNotification(start: Boolean) {
-            (ActionManager.getInstance().getAction(
-                if (start) "statusbar.stopServer" else "statusbar.startServer"
-            ) as LlamaServerToggleActions).apply {
+            getAction(start).apply {
                 this.notification?.expire()
                 this.notification = null
             }
         }
+
+        private fun getAction(start: Boolean) = ActionManager.getInstance().getAction(getId(start)) as LlamaServerToggleActions
+
+        private fun getId(start: Boolean) = if (start) "statusbar.stopServer" else "statusbar.startServer"
    }
 
    var notification: Notification? = null
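One detail worth spelling out (illustrative only, outside the patch): expireOtherNotification clears the notification owned by the opposite toggle action, so a start request expires the stop action's leftover notification and vice versa.

    // Illustration only: mirrors getId(start) above.
    expireOtherNotification(start = true)   // resolves "statusbar.stopServer" and expires its notification
    expireOtherNotification(start = false)  // resolves "statusbar.startServer" and expires its notification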
@@ -43,39 +54,63 @@ abstract class LlamaServerToggleActions(
         notification?.expire()
         expireOtherNotification(startServer)
         val llamaServerAgent = service<LlamaServerAgent>()
+        val serverName = LlamaSettings.getInstance().state.huggingFaceModel.let { findByHuggingFaceModel(it).toString(it) }
         if (startServer) {
-            notification = stickyNotification(CodeGPTBundle.get("settingsConfigurable.service.llama.progress.startingServer"))
-            val serverProgressPanel = ServerProgressPanel()
-            llamaServerAgent.setActiveServerProgressPanel(serverProgressPanel)
-            val settings = LlamaSettings.getInstance().state
-            llamaServerAgent.startAgent(
-                LlamaServerStartupParams(
-                    LlamaSettings.getInstance().actualModelPath,
-                    settings.contextSize,
-                    settings.threads,
-                    settings.serverPort,
-                    getAdditionalParametersList(settings.additionalParameters),
-                    getAdditionalParametersList(settings.additionalBuildParameters)
-                ),
-                serverProgressPanel,
-                {
-                    notification?.expire()
-                    notification = showNotification(CodeGPTBundle.get("settingsConfigurable.service.llama.progress.serverRunning"))
-                },
-                {
-                    Consumer { _: ServerProgressPanel ->
-                        notification?.expire()
-                        notification = showNotification(CodeGPTBundle.get("settingsConfigurable.service.llama.progress.serverStopped"))
-                    }
-                })
+            start(serverName, llamaServerAgent)
         } else {
-            notification = showNotification(CodeGPTBundle.get("settingsConfigurable.service.llama.progress.stoppingServer"))
-            llamaServerAgent.stopAgent()
-            notification?.expire()
-            notification = showNotification(CodeGPTBundle.get("settingsConfigurable.service.llama.progress.serverStopped"))
+            stop(serverName, llamaServerAgent)
         }
     }
 
+    private fun start(serverName: String, llamaServerAgent: LlamaServerAgent) {
+        notification = stickyNotification(formatMsg(STARTING, serverName),
+            createSimpleExpiring(CodeGPTBundle.get(STOP)) { stop(serverName, llamaServerAgent) })
+        val serverProgressPanel = ServerProgressPanel()
+        llamaServerAgent.setActiveServerProgressPanel(serverProgressPanel)
+        val settings = LlamaSettings.getInstance().state
+        llamaServerAgent.startAgent(
+            LlamaServerStartupParams(
+                LlamaSettings.getInstance().actualModelPath,
+                settings.contextSize,
+                settings.threads,
+                settings.serverPort,
+                getAdditionalParametersList(settings.additionalParameters),
+                getAdditionalParametersList(settings.additionalBuildParameters)
+            ),
+            serverProgressPanel,
+            {
+                notification?.expire()
+                notification = notification(RUNNING, false, serverName, llamaServerAgent)
+            },
+            {
+                Consumer { _: ServerProgressPanel ->
+                    notification?.expire()
+                    notification = notification(STOPPED, true, serverName, llamaServerAgent)
+                }
+            })
+    }
+
+    private fun stop(serverName: String, llamaServerAgent: LlamaServerAgent) {
+        notification = showNotification(formatMsg(STOPPING, serverName))
+        llamaServerAgent.stopAgent()
+        notification?.expire()
+        notification = notification(STOPPED, true, serverName, llamaServerAgent)
+    }
+
+    private fun notification(id: String, nextStart: Boolean, serverName: String, llamaServerAgent: LlamaServerAgent) =
+        showNotification(formatMsg(id, serverName),
+            createSimpleExpiring(CodeGPTBundle.get(if (nextStart) START else STOP)) {
+                if (nextStart) start(serverName, llamaServerAgent) else stop(serverName, llamaServerAgent)
+            })
+
+ // "Stopped server" -> "Stopped server: CodeLlama 7B 4-bit" + private fun formatMsg(id: String, serverName: String): String { + val msg = CodeGPTBundle.get(id) + val points = msg.endsWith("...") + return msg.let { if (points) it.substringBeforeLast("...") else it } + ": " + serverName + (if (points) " ..." else "") + } + override fun update(e: AnActionEvent) { val llamaRunnable = isRunnable(LlamaSettings.getInstance().state.huggingFaceModel) val serverRunning = llamaRunnable && service().isServerRunning diff --git a/src/main/kotlin/ee/carlrobert/codegpt/settings/service/ollama/OllamaSettingsForm.kt b/src/main/kotlin/ee/carlrobert/codegpt/settings/service/ollama/OllamaSettingsForm.kt index 9825a2ca..208d85b9 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/settings/service/ollama/OllamaSettingsForm.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/settings/service/ollama/OllamaSettingsForm.kt @@ -142,7 +142,7 @@ class OllamaSettingsForm { NotificationType.ERROR ) } else { - OverlayUtil.showNotification(ex.message, NotificationType.ERROR) + OverlayUtil.showNotification(ex.message ?: "Error", NotificationType.ERROR) } disableModelComboBoxWithPlaceholder(DefaultComboBoxModel(arrayOf("Unable to load models"))) } diff --git a/src/main/resources/messages/codegpt.properties b/src/main/resources/messages/codegpt.properties index 7522e5f7..e851963b 100644 --- a/src/main/resources/messages/codegpt.properties +++ b/src/main/resources/messages/codegpt.properties @@ -84,11 +84,13 @@ settingsConfigurable.service.llama.additionalBuildParameters.comment=Addit settingsConfigurable.service.llama.baseHost.label=Base host: settingsConfigurable.service.llama.baseHost.comment=URL to existing LLama server settingsConfigurable.service.llama.startServer.label=Start server +settingsConfigurable.service.llama.startServer.opposite=Stop settingsConfigurable.service.llama.stopServer.label=Stop server +settingsConfigurable.service.llama.stopServer.opposite=Start settingsConfigurable.service.llama.progress.serverRunning=Server running settingsConfigurable.service.llama.progress.serverStopped=Server stopped -settingsConfigurable.service.llama.progress.stoppingServer=Stopping a server... -settingsConfigurable.service.llama.progress.startingServer=Starting a server... +settingsConfigurable.service.llama.progress.stoppingServer=Stopping server... +settingsConfigurable.service.llama.progress.startingServer=Starting server... settingsConfigurable.service.llama.progress.downloadingModel.title=Downloading Model settingsConfigurable.service.llama.progress.downloadingModelIndicator.text=Downloading %s... settingsConfigurable.service.llama.overlay.modelNotDownloaded.text=Model is not downloaded