feat: support git commit message generation with custom openai and anthropic service (#390)

This commit is contained in:
Carl-Robert Linnupuu 2024-03-12 21:26:33 +02:00
parent 9990c6a57b
commit 8c986fd7de
4 changed files with 148 additions and 61 deletions

View file

@ -2,6 +2,12 @@ package ee.carlrobert.codegpt.actions;
import static com.intellij.openapi.ui.Messages.OK;
import static com.intellij.util.ObjectUtils.tryCast;
import static ee.carlrobert.codegpt.settings.service.ServiceType.ANTHROPIC;
import static ee.carlrobert.codegpt.settings.service.ServiceType.AZURE;
import static ee.carlrobert.codegpt.settings.service.ServiceType.CUSTOM_OPENAI;
import static ee.carlrobert.codegpt.settings.service.ServiceType.LLAMA_CPP;
import static ee.carlrobert.codegpt.settings.service.ServiceType.OPENAI;
import static ee.carlrobert.codegpt.settings.service.ServiceType.YOU;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;
@ -16,10 +22,13 @@ import com.intellij.openapi.editor.Document;
import com.intellij.openapi.editor.Editor;
import com.intellij.openapi.editor.ex.EditorEx;
import com.intellij.openapi.project.Project;
import com.intellij.openapi.vcs.FilePath;
import com.intellij.openapi.vcs.VcsDataKeys;
import com.intellij.openapi.vcs.changes.Change;
import com.intellij.openapi.vcs.changes.ui.ChangesBrowserBase;
import com.intellij.openapi.vcs.changes.ui.CommitDialogChangesBrowser;
import com.intellij.openapi.vcs.ui.CommitMessage;
import com.intellij.openapi.vfs.VirtualFile;
import ee.carlrobert.codegpt.CodeGPTBundle;
import ee.carlrobert.codegpt.EncodingManager;
import ee.carlrobert.codegpt.Icons;
@ -35,8 +44,12 @@ import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Stream;
import okhttp3.sse.EventSource;
import org.jetbrains.annotations.NotNull;
@ -56,23 +69,26 @@ public class GenerateGitCommitMessageAction extends AnAction {
@Override
public void update(@NotNull AnActionEvent event) {
var selectedService = GeneralSettings.getCurrentState().getSelectedService();
if (selectedService == ServiceType.OPENAI || selectedService == ServiceType.AZURE
|| selectedService == ServiceType.LLAMA_CPP) {
var filesSelected = !getReferencedFilePaths(event).isEmpty();
var callAllowed = (selectedService == ServiceType.OPENAI
&& OpenAICredentialManager.getInstance().isCredentialSet())
|| (selectedService == ServiceType.AZURE
&& AzureCredentialsManager.getInstance().isCredentialSet())
|| selectedService == ServiceType.LLAMA_CPP;
event.getPresentation().setEnabled(callAllowed && filesSelected);
event.getPresentation().setText(CodeGPTBundle.get(callAllowed
? "action.generateCommitMessage.title"
: "action.generateCommitMessage.missingCredentials"));
} else {
event.getPresentation().setEnabled(false);
event.getPresentation()
.setText(CodeGPTBundle.get("action.generateCommitMessage.serviceWarning"));
if (selectedService == YOU) {
event.getPresentation().setVisible(false);
return;
}
var includedChangesFilePaths = getIncludedChangesFilePaths(event);
var includedUnversionedChangesFilePaths = getIncludedUnversionedFilePaths(event);
var filesSelected =
!includedChangesFilePaths.isEmpty() || !includedUnversionedChangesFilePaths.isEmpty();
var callAllowed = isCallAllowed(selectedService);
event.getPresentation().setEnabled(callAllowed && filesSelected);
event.getPresentation().setText(CodeGPTBundle.get(callAllowed
? "action.generateCommitMessage.title"
: "action.generateCommitMessage.missingCredentials"));
}
private boolean isCallAllowed(ServiceType serviceType) {
return (serviceType == OPENAI && OpenAICredentialManager.getInstance().isCredentialSet())
|| (serviceType == AZURE && AzureCredentialsManager.getInstance().isCredentialSet())
|| List.of(LLAMA_CPP, ANTHROPIC, CUSTOM_OPENAI).contains(serviceType);
}
@Override
@ -82,7 +98,10 @@ public class GenerateGitCommitMessageAction extends AnAction {
return;
}
var gitDiff = getGitDiff(project, getReferencedFilePaths(event));
var gitDiff = getGitDiff(
project,
getIncludedChangesFilePaths(event),
getIncludedUnversionedFilePaths(event));
var tokenCount = encodingManager.countTokens(gitDiff);
if (tokenCount > MAX_TOKEN_COUNT_WARNING
@ -130,16 +149,31 @@ public class GenerateGitCommitMessageAction extends AnAction {
return commitMessage != null ? commitMessage.getEditorField().getEditor() : null;
}
private String getGitDiff(Project project, List<String> filePaths) {
var process = createGitDiffProcess(project.getBasePath(), filePaths);
var reader = new BufferedReader(new InputStreamReader(process.getInputStream()));
return reader.lines().collect(joining("\n"));
private String getGitDiff(
Project project,
List<String> includedChangesFilePaths,
List<String> includedUnversionedFilePaths) {
return Stream.of(
new AbstractMap.SimpleEntry<>(includedChangesFilePaths, true),
new AbstractMap.SimpleEntry<>(includedUnversionedFilePaths, false))
.filter(entry -> !entry.getKey().isEmpty())
.map(entry -> {
var process =
createGitDiffProcess(project.getBasePath(), entry.getKey(), entry.getValue());
return new BufferedReader(new InputStreamReader(process.getInputStream()))
.lines()
.collect(joining("\n"));
})
.collect(joining("\n"));
}
private Process createGitDiffProcess(String projectPath, List<String> filePaths) {
private Process createGitDiffProcess(String projectPath, List<String> filePaths, boolean cached) {
var command = new ArrayList<String>();
command.add("git");
command.add("diff");
if (cached) {
command.add("--cached");
}
command.addAll(filePaths);
var processBuilder = new ProcessBuilder(command);
@ -151,16 +185,29 @@ public class GenerateGitCommitMessageAction extends AnAction {
}
}
private @NotNull List<String> getReferencedFilePaths(AnActionEvent event) {
private @NotNull List<String> getFilePaths(
AnActionEvent event,
Function<CommitDialogChangesBrowser, Stream<?>> extractor) {
var changesBrowserBase = event.getData(ChangesBrowserBase.DATA_KEY);
if (changesBrowserBase == null) {
return List.of();
}
var includedChanges = ((CommitDialogChangesBrowser) changesBrowserBase).getIncludedChanges();
return includedChanges.stream()
.filter(item -> item.getVirtualFile() != null)
.map(item -> item.getVirtualFile().getPath())
return extractor.apply((CommitDialogChangesBrowser) changesBrowserBase)
.map(obj -> obj instanceof Change
? ((Change) obj).getVirtualFile()
: ((FilePath) obj).getVirtualFile())
.filter(Objects::nonNull)
.map(VirtualFile::getPath)
.distinct()
.collect(toList());
}
private @NotNull List<String> getIncludedChangesFilePaths(AnActionEvent event) {
return getFilePaths(event, browser -> browser.getIncludedChanges().stream());
}
private @NotNull List<String> getIncludedUnversionedFilePaths(AnActionEvent event) {
return getFilePaths(event, browser -> browser.getIncludedUnversionedFiles().stream());
}
}