feat: Show server name in start/stop notifications (#546)

* feat: Show server name in start/stop notifications

* feat: Show opposite action in notification

* feat: Pre-select biggest downloaded parameter size on model change

* chore: Update to latest llama.cpp fixes (2024-05-14)
Rene Leonhardt 2024-05-14 20:26:22 +02:00 committed by GitHub
parent 6de38103d9
commit 8e5ba8158d
7 changed files with 163 additions and 72 deletions

@@ -1 +1 @@
-Subproject commit 46e12c4692a37bdd31a0432fc5153d7d22bc7f72
+Subproject commit e0f556186b6e1f2b7032a1479edf5e89e2b1bd86

View file

@@ -214,6 +214,13 @@ public enum LlamaModel {
return String.join(" ", getDownloadedMarker(), label, getFormattedModelSizeRange());
}
/**
* Display name shown in server start/stop notifications, e.g. {@code CodeLlama 7B 4-bit}.
*/
public @NotNull String toString(@NotNull HuggingFaceModel hfm) {
return "%s %dB %d-bit".formatted(label, hfm.getParameterSize(), hfm.getQuantization());
}
public String getLabel() {
return label;
}
@@ -234,6 +241,16 @@ public enum LlamaModel {
return huggingFaceModels;
}
/**
* Returns the downloaded model with the largest parameter size, or the first model if none has been downloaded.
*/
public HuggingFaceModel getLastExistingModelOrFirst() {
return huggingFaceModels.stream()
.filter(HuggingFaceModel::isDownloaded)
.max(Comparator.comparing(HuggingFaceModel::getParameterSize))
.orElse(huggingFaceModels.get(0));
}
public String getFormattedModelSizeRange() {
var parameters = huggingFaceModels.stream()
.map(HuggingFaceModel::getParameterSize)

View file
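
For context, a minimal standalone sketch of the pre-selection rule that getLastExistingModelOrFirst() above introduces: prefer the downloaded model with the largest parameter size, and fall back to the first entry when nothing has been downloaded yet. The data class and values below are illustrative stand-ins, not the plugin's real model catalog.

// Illustrative stand-in for HuggingFaceModel; only the two fields the rule needs.
data class Model(val parameterSize: Int, val downloaded: Boolean)

fun lastExistingOrFirst(models: List<Model>): Model =
    models.filter { it.downloaded }
        .maxByOrNull { it.parameterSize }
        ?: models.first()

fun main() {
    val catalog = listOf(Model(7, false), Model(13, true), Model(34, false))
    // Prints the 13B entry: the largest *downloaded* size wins over the larger 34B one.
    println(lastExistingOrFirst(catalog))
    // With nothing downloaded, the first entry is returned instead.
    println(lastExistingOrFirst(catalog.map { it.copy(downloaded = false) }))
}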

@@ -113,21 +113,15 @@ public class LlamaModelPreferencesForm {
var llamaServerAgent = ApplicationManager.getApplication().getService(LlamaServerAgent.class);
huggingFaceModelComboBox.setEnabled(!llamaServerAgent.isServerRunning());
var modelSizeComboBoxModel = new DefaultComboBoxModel<ModelSize>();
var initialModelSizes = llamaModel.getSortedUniqueModelSizes();
modelSizeComboBoxModel.addAll(initialModelSizes);
var selectedModelSize = initialModelSizes.stream()
.filter(ms -> ms.size() == llm.getParameterSize())
.findFirst().orElse(initialModelSizes.get(0));
modelSizeComboBoxModel.setSelectedItem(selectedModelSize);
var modelComboBoxModel = new EnumComboBoxModel<>(LlamaModel.class);
modelComboBox = createModelComboBox(modelComboBoxModel, llamaModel, modelSizeComboBoxModel);
modelComboBox = createModelComboBox(
modelComboBoxModel, llamaModel, llm, llamaServerAgent, modelSizeComboBoxModel);
modelComboBox.setEnabled(!llamaServerAgent.isServerRunning());
modelSizeComboBox = createModelSizeComboBox(
modelComboBoxModel,
modelSizeComboBoxModel,
llamaServerAgent,
huggingFaceComboBoxModel);
modelSizeComboBox.setEnabled(
initialModelSizes.size() > 1 && !llamaServerAgent.isServerRunning());
browsableCustomModelTextField = createBrowsableCustomModelTextField(
!llamaServerAgent.isServerRunning());
browsableCustomModelTextField.setText(llamaSettings.getCustomLlamaModelPath());
@@ -310,40 +304,57 @@ public class LlamaModelPreferencesForm {
private ComboBox<LlamaModel> createModelComboBox(
EnumComboBoxModel<LlamaModel> llamaModelEnumComboBoxModel,
LlamaModel llamaModel,
HuggingFaceModel llm,
LlamaServerAgent llamaServerAgent,
DefaultComboBoxModel<ModelSize> modelSizeComboBoxModel) {
var comboBox = new ComboBox<>(llamaModelEnumComboBoxModel);
comboBox.setPreferredSize(new Dimension(280, comboBox.getPreferredSize().height));
comboBox.setSelectedItem(llamaModel);
initializeModelSizes(llamaModel, llm, modelSizeComboBoxModel);
comboBox.addItemListener(e -> {
var selectedModel = (LlamaModel) e.getItem();
var modelSizes = selectedModel.getSortedUniqueModelSizes();
modelSizeComboBoxModel.removeAllElements();
modelSizeComboBoxModel.addAll(modelSizes);
modelSizeComboBoxModel.setSelectedItem(modelSizes.get(0));
modelSizeComboBox.setEnabled(modelSizes.size() > 1);
var huggingFaceModels = selectedModel.filterSelectedModelsBySize(
(ModelSize) modelSizeComboBoxModel.getSelectedItem());
var hfm = selectedModel.getLastExistingModelOrFirst();
var modelSize = initializeModelSizes(selectedModel, hfm, modelSizeComboBoxModel);
var huggingFaceModels = selectedModel.filterSelectedModelsBySize(modelSize);
huggingFaceComboBoxModel.removeAllElements();
huggingFaceComboBoxModel.addAll(huggingFaceModels);
huggingFaceComboBoxModel.setSelectedItem(huggingFaceModels.get(0));
huggingFaceComboBoxModel.setSelectedItem(hfm);
modelSizeComboBox.setEnabled(
modelSizeComboBox.getModel().getSize() > 1 && !llamaServerAgent.isServerRunning());
});
return comboBox;
}
private static ModelSize initializeModelSizes(
LlamaModel llamaModel,
HuggingFaceModel hfm,
DefaultComboBoxModel<ModelSize> modelSizeComboBoxModel) {
var modelSizes = llamaModel.getSortedUniqueModelSizes();
modelSizeComboBoxModel.removeAllElements();
modelSizeComboBoxModel.addAll(modelSizes);
var selectedModelSize = modelSizes.stream()
.filter(ms -> ms.size() == hfm.getParameterSize())
.findFirst().orElse(modelSizes.get(0));
modelSizeComboBoxModel.setSelectedItem(selectedModelSize);
return selectedModelSize;
}
private ComboBox<ModelSize> createModelSizeComboBox(
EnumComboBoxModel<LlamaModel> llamaModelComboBoxModel,
DefaultComboBoxModel<ModelSize> modelSizeComboBoxModel,
LlamaServerAgent llamaServerAgent,
DefaultComboBoxModel<HuggingFaceModel> huggingFaceComboBoxModel) {
var comboBox = new ComboBox<>(modelSizeComboBoxModel);
comboBox.setPreferredSize(modelComboBox.getPreferredSize());
comboBox.setSelectedItem(modelSizeComboBoxModel.getSelectedItem());
comboBox.setEnabled(
modelSizeComboBoxModel.getSize() > 1 && !llamaServerAgent.isServerRunning());
comboBox.addItemListener(e -> {
var selectedModel = llamaModelComboBoxModel.getSelectedItem();
var models = selectedModel.filterSelectedModelsBySize(
(ModelSize) modelSizeComboBoxModel.getSelectedItem());
comboBox.setEnabled(
modelSizeComboBoxModel.getSize() > 1 && !llamaServerAgent.isServerRunning());
if (!models.isEmpty()) {
huggingFaceComboBoxModel.removeAllElements();
huggingFaceComboBoxModel.addAll(models);

View file

@@ -1,5 +1,6 @@
package ee.carlrobert.codegpt.ui;
import static com.intellij.notification.NotificationType.INFORMATION;
import static com.intellij.openapi.ui.Messages.CANCEL;
import static com.intellij.openapi.ui.Messages.OK;
import static ee.carlrobert.codegpt.Icons.Default;
@@ -10,6 +11,7 @@ import com.intellij.execution.ExecutionBundle;
import com.intellij.notification.Notification;
import com.intellij.notification.NotificationType;
import com.intellij.notification.Notifications;
import com.intellij.openapi.actionSystem.AnAction;
import com.intellij.openapi.actionSystem.AnActionEvent;
import com.intellij.openapi.project.Project;
import com.intellij.openapi.ui.DoNotAskOption;
@@ -26,6 +28,7 @@ import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings;
import ee.carlrobert.codegpt.util.EditorUtil;
import java.awt.Point;
import java.awt.event.MouseEvent;
import java.util.Arrays;
import javax.swing.JComponent;
import org.jetbrains.annotations.NotNull;
@@ -37,30 +40,53 @@ public class OverlayUtil {
private OverlayUtil() {
}
public static Notification getDefaultNotification(String content, NotificationType type) {
return new Notification(NOTIFICATION_GROUP_ID, "CodeGPT", content, type);
public static Notification getDefaultNotification(
@NotNull String content, @NotNull AnAction... actions) {
return getDefaultNotification(content, INFORMATION, actions);
}
public static Notification getStickyNotification(String content, NotificationType type) {
return new Notification(NOTIFICATION_GROUP_STICKY_ID, "CodeGPT", content, type);
}
public static Notification showNotification(String content) {
return showNotification(content, NotificationType.INFORMATION);
}
public static Notification showNotification(String content, NotificationType type) {
var notification = getDefaultNotification(content, type);
Notifications.Bus.notify(notification);
public static Notification getDefaultNotification(
@NotNull String content, @NotNull NotificationType type, @NotNull AnAction... actions) {
var notification = new Notification(NOTIFICATION_GROUP_ID, "CodeGPT", content, type);
Arrays.asList(actions).forEach(notification::addAction);
return notification;
}
public static Notification stickyNotification(String content) {
return stickyNotification(content, NotificationType.INFORMATION);
public static Notification getStickyNotification(
@NotNull String content, @NotNull AnAction... actions) {
return getStickyNotification(content, INFORMATION, actions);
}
public static Notification stickyNotification(String content, NotificationType type) {
var notification = getStickyNotification(content, type);
public static Notification getStickyNotification(
@NotNull String content, @NotNull NotificationType type, @NotNull AnAction... actions) {
var notification = new Notification(NOTIFICATION_GROUP_STICKY_ID, "CodeGPT", content, type);
Arrays.asList(actions).forEach(notification::addAction);
return notification;
}
public static Notification showNotification(
@NotNull String content, @NotNull AnAction... actions) {
return showNotification(content, INFORMATION, actions);
}
public static Notification showNotification(
@NotNull String content, @NotNull NotificationType type, @NotNull AnAction... actions) {
return notify(getDefaultNotification(content, type, actions));
}
public static Notification stickyNotification(
@NotNull String content, @NotNull AnAction... actions) {
return stickyNotification(content, INFORMATION, actions);
}
public static Notification stickyNotification(
@NotNull String content, @NotNull NotificationType type, @NotNull AnAction... actions) {
return notify(getStickyNotification(content, type, actions));
}
public static @NotNull Notification notify(
@NotNull Notification notification, @NotNull AnAction... actions) {
Arrays.asList(actions).forEach(notification::addAction);
Notifications.Bus.notify(notification);
return notification;
}

View file
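
A minimal usage sketch of the reworked OverlayUtil helpers above, assuming the varargs signatures shown in that hunk; the notification texts and the action body are illustrative only.

import com.intellij.notification.NotificationAction.createSimpleExpiring
import ee.carlrobert.codegpt.ui.OverlayUtil

fun demoServerNotifications(stopServer: () -> Unit) {
    // Sticky "starting" notification that carries the opposite ("Stop") action.
    val starting = OverlayUtil.stickyNotification(
        "Starting server: CodeLlama 7B 4-bit ...",
        createSimpleExpiring("Stop") { stopServer() })

    // Once the server is up, expire the sticky one and show a plain info notification.
    starting.expire()
    OverlayUtil.showNotification("Server running: CodeLlama 7B 4-bit")
}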

@@ -1,12 +1,14 @@
package ee.carlrobert.codegpt.actions
import com.intellij.notification.Notification
import com.intellij.notification.NotificationAction.createSimpleExpiring
import com.intellij.openapi.actionSystem.ActionManager
import com.intellij.openapi.actionSystem.ActionUpdateThread
import com.intellij.openapi.actionSystem.AnActionEvent
import com.intellij.openapi.components.service
import com.intellij.openapi.project.DumbAwareAction
import ee.carlrobert.codegpt.CodeGPTBundle
import ee.carlrobert.codegpt.completions.llama.LlamaModel.findByHuggingFaceModel
import ee.carlrobert.codegpt.completions.llama.LlamaServerAgent
import ee.carlrobert.codegpt.completions.llama.LlamaServerStartupParams
import ee.carlrobert.codegpt.settings.GeneralSettings
@@ -19,6 +21,13 @@ import ee.carlrobert.codegpt.ui.OverlayUtil.showNotification
import ee.carlrobert.codegpt.ui.OverlayUtil.stickyNotification
import java.util.function.Consumer
private const val STARTING = "settingsConfigurable.service.llama.progress.startingServer"
private const val RUNNING = "settingsConfigurable.service.llama.progress.serverRunning"
private const val STOPPING = "settingsConfigurable.service.llama.progress.stoppingServer"
private const val STOPPED = "settingsConfigurable.service.llama.progress.serverStopped"
private const val START = "settingsConfigurable.service.llama.stopServer.opposite"
private const val STOP = "settingsConfigurable.service.llama.startServer.opposite"
/**
* Starts or stops the server (if the selected model exists), showing a notification for each state change.
*/
@@ -27,13 +36,15 @@ abstract class LlamaServerToggleActions(
) : DumbAwareAction() {
companion object {
fun expireOtherNotification(start: Boolean) {
(ActionManager.getInstance().getAction(
if (start) "statusbar.stopServer" else "statusbar.startServer"
) as LlamaServerToggleActions).apply {
getAction(start).apply {
this.notification?.expire()
this.notification = null
}
}
private fun getAction(start: Boolean) = ActionManager.getInstance().getAction(getId(start)) as LlamaServerToggleActions
private fun getId(start: Boolean) = if (start) "statusbar.stopServer" else "statusbar.startServer"
}
var notification: Notification? = null
@@ -43,39 +54,63 @@ abstract class LlamaServerToggleActions(
notification?.expire()
expireOtherNotification(startServer)
val llamaServerAgent = service<LlamaServerAgent>()
val serverName = LlamaSettings.getInstance().state.huggingFaceModel.let { findByHuggingFaceModel(it).toString(it) }
if (startServer) {
notification = stickyNotification(CodeGPTBundle.get("settingsConfigurable.service.llama.progress.startingServer"))
val serverProgressPanel = ServerProgressPanel()
llamaServerAgent.setActiveServerProgressPanel(serverProgressPanel)
val settings = LlamaSettings.getInstance().state
llamaServerAgent.startAgent(
LlamaServerStartupParams(
LlamaSettings.getInstance().actualModelPath,
settings.contextSize,
settings.threads,
settings.serverPort,
getAdditionalParametersList(settings.additionalParameters),
getAdditionalParametersList(settings.additionalBuildParameters)
),
serverProgressPanel,
{
notification?.expire()
notification = showNotification(CodeGPTBundle.get("settingsConfigurable.service.llama.progress.serverRunning"))
},
{
Consumer<ServerProgressPanel> { _: ServerProgressPanel ->
notification?.expire()
notification = showNotification(CodeGPTBundle.get("settingsConfigurable.service.llama.progress.serverStopped"))
}
})
start(serverName, llamaServerAgent)
} else {
notification = showNotification(CodeGPTBundle.get("settingsConfigurable.service.llama.progress.stoppingServer"))
llamaServerAgent.stopAgent()
notification?.expire()
notification = showNotification(CodeGPTBundle.get("settingsConfigurable.service.llama.progress.serverStopped"))
stop(serverName, llamaServerAgent)
}
}
private fun start(serverName: String, llamaServerAgent: LlamaServerAgent) {
notification = stickyNotification(formatMsg(STARTING, serverName),
createSimpleExpiring(CodeGPTBundle.get(STOP)) { stop(serverName, llamaServerAgent) })
val serverProgressPanel = ServerProgressPanel()
llamaServerAgent.setActiveServerProgressPanel(serverProgressPanel)
val settings = LlamaSettings.getInstance().state
llamaServerAgent.startAgent(
LlamaServerStartupParams(
LlamaSettings.getInstance().actualModelPath,
settings.contextSize,
settings.threads,
settings.serverPort,
getAdditionalParametersList(settings.additionalParameters),
getAdditionalParametersList(settings.additionalBuildParameters)
),
serverProgressPanel,
{
notification?.expire()
notification = notification(RUNNING, false, serverName, llamaServerAgent)
},
{
Consumer<ServerProgressPanel> { _: ServerProgressPanel ->
notification?.expire()
notification = notification(STOPPED, true, serverName, llamaServerAgent)
}
})
}
private fun stop(serverName: String, llamaServerAgent: LlamaServerAgent) {
notification = showNotification(formatMsg(STOPPING, serverName))
llamaServerAgent.stopAgent()
notification?.expire()
notification = notification(STOPPED, true, serverName, llamaServerAgent)
}
private fun notification(id: String, nextStart: Boolean, serverName: String, llamaServerAgent: LlamaServerAgent) =
showNotification(formatMsg(id, serverName),
createSimpleExpiring(CodeGPTBundle.get(if (nextStart) START else STOP)) {
if (nextStart) start(serverName, llamaServerAgent) else stop(serverName, llamaServerAgent)
})
// "Starting server..." -> "Starting server: CodeLlama 7B 4-bit ..."
// "Stopped server" -> "Stopped server: CodeLlama 7B 4-bit"
private fun formatMsg(id: String, serverName: String): String {
val msg = CodeGPTBundle.get(id)
val points = msg.endsWith("...")
return msg.let { if (points) it.substringBeforeLast("...") else it } + ": " + serverName + (if (points) " ..." else "")
}
override fun update(e: AnActionEvent) {
val llamaRunnable = isRunnable(LlamaSettings.getInstance().state.huggingFaceModel)
val serverRunning = llamaRunnable && service<LlamaServerAgent>().isServerRunning

View file
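
To make the formatMsg transformation above concrete, a small self-contained sketch; the function body mirrors the one in the hunk, but the CodeGPTBundle lookup is replaced by literal messages for illustration.

// Mirrors formatMsg above, taking the resolved bundle message directly.
fun formatMsg(msg: String, serverName: String): String {
    val points = msg.endsWith("...")
    val base = if (points) msg.substringBeforeLast("...") else msg
    return base + ": " + serverName + (if (points) " ..." else "")
}

fun main() {
    // "Starting server..." -> "Starting server: CodeLlama 7B 4-bit ..."
    println(formatMsg("Starting server...", "CodeLlama 7B 4-bit"))
    // "Server stopped" -> "Server stopped: CodeLlama 7B 4-bit"
    println(formatMsg("Server stopped", "CodeLlama 7B 4-bit"))
}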

@@ -142,7 +142,7 @@ class OllamaSettingsForm {
NotificationType.ERROR
)
} else {
OverlayUtil.showNotification(ex.message, NotificationType.ERROR)
OverlayUtil.showNotification(ex.message ?: "Error", NotificationType.ERROR)
}
disableModelComboBoxWithPlaceholder(DefaultComboBoxModel(arrayOf("Unable to load models")))
}

View file

@@ -84,11 +84,13 @@ settingsConfigurable.service.llama.additionalBuildParameters.comment=<html>Addit
settingsConfigurable.service.llama.baseHost.label=Base host:
settingsConfigurable.service.llama.baseHost.comment=URL to existing LLama server
settingsConfigurable.service.llama.startServer.label=Start server
settingsConfigurable.service.llama.startServer.opposite=Stop
settingsConfigurable.service.llama.stopServer.label=Stop server
settingsConfigurable.service.llama.stopServer.opposite=Start
settingsConfigurable.service.llama.progress.serverRunning=Server running
settingsConfigurable.service.llama.progress.serverStopped=Server stopped
settingsConfigurable.service.llama.progress.stoppingServer=Stopping a server...
settingsConfigurable.service.llama.progress.startingServer=Starting a server...
settingsConfigurable.service.llama.progress.stoppingServer=Stopping server...
settingsConfigurable.service.llama.progress.startingServer=Starting server...
settingsConfigurable.service.llama.progress.downloadingModel.title=Downloading Model
settingsConfigurable.service.llama.progress.downloadingModelIndicator.text=Downloading %s...
settingsConfigurable.service.llama.overlay.modelNotDownloaded.text=Model is not downloaded