mirror of
https://github.com/carlrobertoh/ProxyAI.git
synced 2026-05-10 03:59:43 +00:00
feat: Show server name in start/stop notifications (#546)
* feat: Show server name in start/stop notifications * feat: Show opposite action in notification * feat: Pre-select biggest downloaded parameter size on model change * chore: Update to latest llama.cpp fixes (2024-05-14)
This commit is contained in:
parent
6de38103d9
commit
8e5ba8158d
7 changed files with 163 additions and 72 deletions
|
|
@ -1 +1 @@
|
|||
Subproject commit 46e12c4692a37bdd31a0432fc5153d7d22bc7f72
|
||||
Subproject commit e0f556186b6e1f2b7032a1479edf5e89e2b1bd86
|
||||
|
|
@ -214,6 +214,13 @@ public enum LlamaModel {
|
|||
return String.join(" ", getDownloadedMarker(), label, getFormattedModelSizeRange());
|
||||
}
|
||||
|
||||
/**
|
||||
* Server started: {@code CodeLlama 7B 4-bit}.
|
||||
*/
|
||||
public @NotNull String toString(@NotNull HuggingFaceModel hfm) {
|
||||
return "%s %dB %d-bit".formatted(label, hfm.getParameterSize(), hfm.getQuantization());
|
||||
}
|
||||
|
||||
public String getLabel() {
|
||||
return label;
|
||||
}
|
||||
|
|
@ -234,6 +241,16 @@ public enum LlamaModel {
|
|||
return huggingFaceModels;
|
||||
}
|
||||
|
||||
/**
|
||||
* Downloaded model with the biggest parameter size, otherwise first.
|
||||
*/
|
||||
public HuggingFaceModel getLastExistingModelOrFirst() {
|
||||
return huggingFaceModels.stream()
|
||||
.filter(HuggingFaceModel::isDownloaded)
|
||||
.max(Comparator.comparing(HuggingFaceModel::getParameterSize))
|
||||
.orElse(huggingFaceModels.get(0));
|
||||
}
|
||||
|
||||
public String getFormattedModelSizeRange() {
|
||||
var parameters = huggingFaceModels.stream()
|
||||
.map(HuggingFaceModel::getParameterSize)
|
||||
|
|
|
|||
|
|
@ -113,21 +113,15 @@ public class LlamaModelPreferencesForm {
|
|||
var llamaServerAgent = ApplicationManager.getApplication().getService(LlamaServerAgent.class);
|
||||
huggingFaceModelComboBox.setEnabled(!llamaServerAgent.isServerRunning());
|
||||
var modelSizeComboBoxModel = new DefaultComboBoxModel<ModelSize>();
|
||||
var initialModelSizes = llamaModel.getSortedUniqueModelSizes();
|
||||
modelSizeComboBoxModel.addAll(initialModelSizes);
|
||||
var selectedModelSize = initialModelSizes.stream()
|
||||
.filter(ms -> ms.size() == llm.getParameterSize())
|
||||
.findFirst().orElse(initialModelSizes.get(0));
|
||||
modelSizeComboBoxModel.setSelectedItem(selectedModelSize);
|
||||
var modelComboBoxModel = new EnumComboBoxModel<>(LlamaModel.class);
|
||||
modelComboBox = createModelComboBox(modelComboBoxModel, llamaModel, modelSizeComboBoxModel);
|
||||
modelComboBox = createModelComboBox(
|
||||
modelComboBoxModel, llamaModel, llm, llamaServerAgent, modelSizeComboBoxModel);
|
||||
modelComboBox.setEnabled(!llamaServerAgent.isServerRunning());
|
||||
modelSizeComboBox = createModelSizeComboBox(
|
||||
modelComboBoxModel,
|
||||
modelSizeComboBoxModel,
|
||||
llamaServerAgent,
|
||||
huggingFaceComboBoxModel);
|
||||
modelSizeComboBox.setEnabled(
|
||||
initialModelSizes.size() > 1 && !llamaServerAgent.isServerRunning());
|
||||
browsableCustomModelTextField = createBrowsableCustomModelTextField(
|
||||
!llamaServerAgent.isServerRunning());
|
||||
browsableCustomModelTextField.setText(llamaSettings.getCustomLlamaModelPath());
|
||||
|
|
@ -310,40 +304,57 @@ public class LlamaModelPreferencesForm {
|
|||
private ComboBox<LlamaModel> createModelComboBox(
|
||||
EnumComboBoxModel<LlamaModel> llamaModelEnumComboBoxModel,
|
||||
LlamaModel llamaModel,
|
||||
HuggingFaceModel llm,
|
||||
LlamaServerAgent llamaServerAgent,
|
||||
DefaultComboBoxModel<ModelSize> modelSizeComboBoxModel) {
|
||||
var comboBox = new ComboBox<>(llamaModelEnumComboBoxModel);
|
||||
comboBox.setPreferredSize(new Dimension(280, comboBox.getPreferredSize().height));
|
||||
comboBox.setSelectedItem(llamaModel);
|
||||
initializeModelSizes(llamaModel, llm, modelSizeComboBoxModel);
|
||||
comboBox.addItemListener(e -> {
|
||||
var selectedModel = (LlamaModel) e.getItem();
|
||||
var modelSizes = selectedModel.getSortedUniqueModelSizes();
|
||||
|
||||
modelSizeComboBoxModel.removeAllElements();
|
||||
modelSizeComboBoxModel.addAll(modelSizes);
|
||||
modelSizeComboBoxModel.setSelectedItem(modelSizes.get(0));
|
||||
modelSizeComboBox.setEnabled(modelSizes.size() > 1);
|
||||
|
||||
var huggingFaceModels = selectedModel.filterSelectedModelsBySize(
|
||||
(ModelSize) modelSizeComboBoxModel.getSelectedItem());
|
||||
|
||||
var hfm = selectedModel.getLastExistingModelOrFirst();
|
||||
var modelSize = initializeModelSizes(selectedModel, hfm, modelSizeComboBoxModel);
|
||||
var huggingFaceModels = selectedModel.filterSelectedModelsBySize(modelSize);
|
||||
huggingFaceComboBoxModel.removeAllElements();
|
||||
huggingFaceComboBoxModel.addAll(huggingFaceModels);
|
||||
huggingFaceComboBoxModel.setSelectedItem(huggingFaceModels.get(0));
|
||||
huggingFaceComboBoxModel.setSelectedItem(hfm);
|
||||
modelSizeComboBox.setEnabled(
|
||||
modelSizeComboBox.getModel().getSize() > 1 && !llamaServerAgent.isServerRunning());
|
||||
});
|
||||
return comboBox;
|
||||
}
|
||||
|
||||
private static ModelSize initializeModelSizes(
|
||||
LlamaModel llamaModel,
|
||||
HuggingFaceModel hfm,
|
||||
DefaultComboBoxModel<ModelSize> modelSizeComboBoxModel) {
|
||||
var modelSizes = llamaModel.getSortedUniqueModelSizes();
|
||||
modelSizeComboBoxModel.removeAllElements();
|
||||
modelSizeComboBoxModel.addAll(modelSizes);
|
||||
var selectedModelSize = modelSizes.stream()
|
||||
.filter(ms -> ms.size() == hfm.getParameterSize())
|
||||
.findFirst().orElse(modelSizes.get(0));
|
||||
modelSizeComboBoxModel.setSelectedItem(selectedModelSize);
|
||||
return selectedModelSize;
|
||||
}
|
||||
|
||||
private ComboBox<ModelSize> createModelSizeComboBox(
|
||||
EnumComboBoxModel<LlamaModel> llamaModelComboBoxModel,
|
||||
DefaultComboBoxModel<ModelSize> modelSizeComboBoxModel,
|
||||
LlamaServerAgent llamaServerAgent,
|
||||
DefaultComboBoxModel<HuggingFaceModel> huggingFaceComboBoxModel) {
|
||||
var comboBox = new ComboBox<>(modelSizeComboBoxModel);
|
||||
comboBox.setPreferredSize(modelComboBox.getPreferredSize());
|
||||
comboBox.setSelectedItem(modelSizeComboBoxModel.getSelectedItem());
|
||||
comboBox.setEnabled(
|
||||
modelSizeComboBoxModel.getSize() > 1 && !llamaServerAgent.isServerRunning());
|
||||
comboBox.addItemListener(e -> {
|
||||
var selectedModel = llamaModelComboBoxModel.getSelectedItem();
|
||||
var models = selectedModel.filterSelectedModelsBySize(
|
||||
(ModelSize) modelSizeComboBoxModel.getSelectedItem());
|
||||
comboBox.setEnabled(
|
||||
modelSizeComboBoxModel.getSize() > 1 && !llamaServerAgent.isServerRunning());
|
||||
if (!models.isEmpty()) {
|
||||
huggingFaceComboBoxModel.removeAllElements();
|
||||
huggingFaceComboBoxModel.addAll(models);
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
package ee.carlrobert.codegpt.ui;
|
||||
|
||||
import static com.intellij.notification.NotificationType.INFORMATION;
|
||||
import static com.intellij.openapi.ui.Messages.CANCEL;
|
||||
import static com.intellij.openapi.ui.Messages.OK;
|
||||
import static ee.carlrobert.codegpt.Icons.Default;
|
||||
|
|
@ -10,6 +11,7 @@ import com.intellij.execution.ExecutionBundle;
|
|||
import com.intellij.notification.Notification;
|
||||
import com.intellij.notification.NotificationType;
|
||||
import com.intellij.notification.Notifications;
|
||||
import com.intellij.openapi.actionSystem.AnAction;
|
||||
import com.intellij.openapi.actionSystem.AnActionEvent;
|
||||
import com.intellij.openapi.project.Project;
|
||||
import com.intellij.openapi.ui.DoNotAskOption;
|
||||
|
|
@ -26,6 +28,7 @@ import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings;
|
|||
import ee.carlrobert.codegpt.util.EditorUtil;
|
||||
import java.awt.Point;
|
||||
import java.awt.event.MouseEvent;
|
||||
import java.util.Arrays;
|
||||
import javax.swing.JComponent;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
|
||||
|
|
@ -37,30 +40,53 @@ public class OverlayUtil {
|
|||
private OverlayUtil() {
|
||||
}
|
||||
|
||||
public static Notification getDefaultNotification(String content, NotificationType type) {
|
||||
return new Notification(NOTIFICATION_GROUP_ID, "CodeGPT", content, type);
|
||||
public static Notification getDefaultNotification(
|
||||
@NotNull String content, @NotNull AnAction... actions) {
|
||||
return getDefaultNotification(content, INFORMATION, actions);
|
||||
}
|
||||
|
||||
public static Notification getStickyNotification(String content, NotificationType type) {
|
||||
return new Notification(NOTIFICATION_GROUP_STICKY_ID, "CodeGPT", content, type);
|
||||
}
|
||||
|
||||
public static Notification showNotification(String content) {
|
||||
return showNotification(content, NotificationType.INFORMATION);
|
||||
}
|
||||
|
||||
public static Notification showNotification(String content, NotificationType type) {
|
||||
var notification = getDefaultNotification(content, type);
|
||||
Notifications.Bus.notify(notification);
|
||||
public static Notification getDefaultNotification(
|
||||
@NotNull String content, @NotNull NotificationType type, @NotNull AnAction... actions) {
|
||||
var notification = new Notification(NOTIFICATION_GROUP_ID, "CodeGPT", content, type);
|
||||
Arrays.asList(actions).forEach(notification::addAction);
|
||||
return notification;
|
||||
}
|
||||
|
||||
public static Notification stickyNotification(String content) {
|
||||
return stickyNotification(content, NotificationType.INFORMATION);
|
||||
public static Notification getStickyNotification(
|
||||
@NotNull String content, @NotNull AnAction... actions) {
|
||||
return getStickyNotification(content, INFORMATION, actions);
|
||||
}
|
||||
|
||||
public static Notification stickyNotification(String content, NotificationType type) {
|
||||
var notification = getStickyNotification(content, type);
|
||||
public static Notification getStickyNotification(
|
||||
@NotNull String content, @NotNull NotificationType type, @NotNull AnAction... actions) {
|
||||
var notification = new Notification(NOTIFICATION_GROUP_STICKY_ID, "CodeGPT", content, type);
|
||||
Arrays.asList(actions).forEach(notification::addAction);
|
||||
return notification;
|
||||
}
|
||||
|
||||
public static Notification showNotification(
|
||||
@NotNull String content, @NotNull AnAction... actions) {
|
||||
return showNotification(content, INFORMATION, actions);
|
||||
}
|
||||
|
||||
public static Notification showNotification(
|
||||
@NotNull String content, @NotNull NotificationType type, @NotNull AnAction... actions) {
|
||||
return notify(getDefaultNotification(content, type, actions));
|
||||
}
|
||||
|
||||
public static Notification stickyNotification(
|
||||
@NotNull String content, @NotNull AnAction... actions) {
|
||||
return stickyNotification(content, INFORMATION, actions);
|
||||
}
|
||||
|
||||
public static Notification stickyNotification(
|
||||
@NotNull String content, @NotNull NotificationType type, @NotNull AnAction... actions) {
|
||||
return notify(getStickyNotification(content, type, actions));
|
||||
}
|
||||
|
||||
public static @NotNull Notification notify(
|
||||
@NotNull Notification notification, @NotNull AnAction... actions) {
|
||||
Arrays.asList(actions).forEach(notification::addAction);
|
||||
Notifications.Bus.notify(notification);
|
||||
return notification;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,12 +1,14 @@
|
|||
package ee.carlrobert.codegpt.actions
|
||||
|
||||
import com.intellij.notification.Notification
|
||||
import com.intellij.notification.NotificationAction.createSimpleExpiring
|
||||
import com.intellij.openapi.actionSystem.ActionManager
|
||||
import com.intellij.openapi.actionSystem.ActionUpdateThread
|
||||
import com.intellij.openapi.actionSystem.AnActionEvent
|
||||
import com.intellij.openapi.components.service
|
||||
import com.intellij.openapi.project.DumbAwareAction
|
||||
import ee.carlrobert.codegpt.CodeGPTBundle
|
||||
import ee.carlrobert.codegpt.completions.llama.LlamaModel.findByHuggingFaceModel
|
||||
import ee.carlrobert.codegpt.completions.llama.LlamaServerAgent
|
||||
import ee.carlrobert.codegpt.completions.llama.LlamaServerStartupParams
|
||||
import ee.carlrobert.codegpt.settings.GeneralSettings
|
||||
|
|
@ -19,6 +21,13 @@ import ee.carlrobert.codegpt.ui.OverlayUtil.showNotification
|
|||
import ee.carlrobert.codegpt.ui.OverlayUtil.stickyNotification
|
||||
import java.util.function.Consumer
|
||||
|
||||
private const val STARTING = "settingsConfigurable.service.llama.progress.startingServer"
|
||||
private const val RUNNING = "settingsConfigurable.service.llama.progress.serverRunning"
|
||||
private const val STOPPING = "settingsConfigurable.service.llama.progress.stoppingServer"
|
||||
private const val STOPPED = "settingsConfigurable.service.llama.progress.serverStopped"
|
||||
private const val START = "settingsConfigurable.service.llama.stopServer.opposite"
|
||||
private const val STOP = "settingsConfigurable.service.llama.startServer.opposite"
|
||||
|
||||
/**
|
||||
* Start or stop server (if selected model exists) showing notifications
|
||||
*/
|
||||
|
|
@ -27,13 +36,15 @@ abstract class LlamaServerToggleActions(
|
|||
) : DumbAwareAction() {
|
||||
companion object {
|
||||
fun expireOtherNotification(start: Boolean) {
|
||||
(ActionManager.getInstance().getAction(
|
||||
if (start) "statusbar.stopServer" else "statusbar.startServer"
|
||||
) as LlamaServerToggleActions).apply {
|
||||
getAction(start).apply {
|
||||
this.notification?.expire()
|
||||
this.notification = null
|
||||
}
|
||||
}
|
||||
|
||||
private fun getAction(start: Boolean) = ActionManager.getInstance().getAction(getId(start)) as LlamaServerToggleActions
|
||||
|
||||
private fun getId(start: Boolean) = if (start) "statusbar.stopServer" else "statusbar.startServer"
|
||||
}
|
||||
|
||||
var notification: Notification? = null
|
||||
|
|
@ -43,39 +54,63 @@ abstract class LlamaServerToggleActions(
|
|||
notification?.expire()
|
||||
expireOtherNotification(startServer)
|
||||
val llamaServerAgent = service<LlamaServerAgent>()
|
||||
val serverName = LlamaSettings.getInstance().state.huggingFaceModel.let { findByHuggingFaceModel(it).toString(it) }
|
||||
if (startServer) {
|
||||
notification = stickyNotification(CodeGPTBundle.get("settingsConfigurable.service.llama.progress.startingServer"))
|
||||
val serverProgressPanel = ServerProgressPanel()
|
||||
llamaServerAgent.setActiveServerProgressPanel(serverProgressPanel)
|
||||
val settings = LlamaSettings.getInstance().state
|
||||
llamaServerAgent.startAgent(
|
||||
LlamaServerStartupParams(
|
||||
LlamaSettings.getInstance().actualModelPath,
|
||||
settings.contextSize,
|
||||
settings.threads,
|
||||
settings.serverPort,
|
||||
getAdditionalParametersList(settings.additionalParameters),
|
||||
getAdditionalParametersList(settings.additionalBuildParameters)
|
||||
),
|
||||
serverProgressPanel,
|
||||
{
|
||||
notification?.expire()
|
||||
notification = showNotification(CodeGPTBundle.get("settingsConfigurable.service.llama.progress.serverRunning"))
|
||||
},
|
||||
{
|
||||
Consumer<ServerProgressPanel> { _: ServerProgressPanel ->
|
||||
notification?.expire()
|
||||
notification = showNotification(CodeGPTBundle.get("settingsConfigurable.service.llama.progress.serverStopped"))
|
||||
}
|
||||
})
|
||||
start(serverName, llamaServerAgent)
|
||||
} else {
|
||||
notification = showNotification(CodeGPTBundle.get("settingsConfigurable.service.llama.progress.stoppingServer"))
|
||||
llamaServerAgent.stopAgent()
|
||||
notification?.expire()
|
||||
notification = showNotification(CodeGPTBundle.get("settingsConfigurable.service.llama.progress.serverStopped"))
|
||||
stop(serverName, llamaServerAgent)
|
||||
}
|
||||
}
|
||||
|
||||
private fun start(serverName: String, llamaServerAgent: LlamaServerAgent) {
|
||||
notification = stickyNotification(formatMsg(STARTING, serverName),
|
||||
createSimpleExpiring(CodeGPTBundle.get(STOP)) { stop(serverName, llamaServerAgent) })
|
||||
val serverProgressPanel = ServerProgressPanel()
|
||||
llamaServerAgent.setActiveServerProgressPanel(serverProgressPanel)
|
||||
val settings = LlamaSettings.getInstance().state
|
||||
llamaServerAgent.startAgent(
|
||||
LlamaServerStartupParams(
|
||||
LlamaSettings.getInstance().actualModelPath,
|
||||
settings.contextSize,
|
||||
settings.threads,
|
||||
settings.serverPort,
|
||||
getAdditionalParametersList(settings.additionalParameters),
|
||||
getAdditionalParametersList(settings.additionalBuildParameters)
|
||||
),
|
||||
serverProgressPanel,
|
||||
{
|
||||
notification?.expire()
|
||||
notification = notification(RUNNING, false, serverName, llamaServerAgent)
|
||||
},
|
||||
{
|
||||
Consumer<ServerProgressPanel> { _: ServerProgressPanel ->
|
||||
notification?.expire()
|
||||
notification = notification(STOPPED, true, serverName, llamaServerAgent)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
private fun stop(serverName: String, llamaServerAgent: LlamaServerAgent) {
|
||||
notification = showNotification(formatMsg(STOPPING, serverName))
|
||||
llamaServerAgent.stopAgent()
|
||||
notification?.expire()
|
||||
notification = notification(STOPPED, true, serverName, llamaServerAgent)
|
||||
}
|
||||
|
||||
private fun notification(id: String, nextStart: Boolean, serverName: String, llamaServerAgent: LlamaServerAgent) =
|
||||
showNotification(formatMsg(id, serverName),
|
||||
createSimpleExpiring(CodeGPTBundle.get(if (nextStart) START else STOP)) {
|
||||
if (nextStart) start(serverName, llamaServerAgent) else stop(serverName, llamaServerAgent)
|
||||
})
|
||||
|
||||
// "Starting server..." -> "Starting server: CodeLlama 7B 4-bit ..."
|
||||
// "Stopped server" -> "Stopped server: CodeLlama 7B 4-bit"
|
||||
private fun formatMsg(id: String, serverName: String): String {
|
||||
val msg = CodeGPTBundle.get(id)
|
||||
val points = msg.endsWith("...")
|
||||
return msg.let { if (points) it.substringBeforeLast("...") else it } + ": " + serverName + (if (points) " ..." else "")
|
||||
}
|
||||
|
||||
override fun update(e: AnActionEvent) {
|
||||
val llamaRunnable = isRunnable(LlamaSettings.getInstance().state.huggingFaceModel)
|
||||
val serverRunning = llamaRunnable && service<LlamaServerAgent>().isServerRunning
|
||||
|
|
|
|||
|
|
@ -142,7 +142,7 @@ class OllamaSettingsForm {
|
|||
NotificationType.ERROR
|
||||
)
|
||||
} else {
|
||||
OverlayUtil.showNotification(ex.message, NotificationType.ERROR)
|
||||
OverlayUtil.showNotification(ex.message ?: "Error", NotificationType.ERROR)
|
||||
}
|
||||
disableModelComboBoxWithPlaceholder(DefaultComboBoxModel(arrayOf("Unable to load models")))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -84,11 +84,13 @@ settingsConfigurable.service.llama.additionalBuildParameters.comment=<html>Addit
|
|||
settingsConfigurable.service.llama.baseHost.label=Base host:
|
||||
settingsConfigurable.service.llama.baseHost.comment=URL to existing LLama server
|
||||
settingsConfigurable.service.llama.startServer.label=Start server
|
||||
settingsConfigurable.service.llama.startServer.opposite=Stop
|
||||
settingsConfigurable.service.llama.stopServer.label=Stop server
|
||||
settingsConfigurable.service.llama.stopServer.opposite=Start
|
||||
settingsConfigurable.service.llama.progress.serverRunning=Server running
|
||||
settingsConfigurable.service.llama.progress.serverStopped=Server stopped
|
||||
settingsConfigurable.service.llama.progress.stoppingServer=Stopping a server...
|
||||
settingsConfigurable.service.llama.progress.startingServer=Starting a server...
|
||||
settingsConfigurable.service.llama.progress.stoppingServer=Stopping server...
|
||||
settingsConfigurable.service.llama.progress.startingServer=Starting server...
|
||||
settingsConfigurable.service.llama.progress.downloadingModel.title=Downloading Model
|
||||
settingsConfigurable.service.llama.progress.downloadingModelIndicator.text=Downloading %s...
|
||||
settingsConfigurable.service.llama.overlay.modelNotDownloaded.text=Model is not downloaded
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue