From 8c986fd7de109d704e590639fc943a703a2af5be Mon Sep 17 00:00:00 2001 From: Carl-Robert Linnupuu Date: Tue, 12 Mar 2024 21:26:33 +0200 Subject: [PATCH] feat: support git commit message generation with custom openai and anthropic service (#390) --- .../GenerateGitCommitMessageAction.java | 101 +++++++++++++----- .../CompletionRequestProvider.java | 9 ++ .../completions/CompletionRequestService.java | 97 +++++++++++------ .../resources/messages/codegpt.properties | 2 +- 4 files changed, 148 insertions(+), 61 deletions(-) diff --git a/src/main/java/ee/carlrobert/codegpt/actions/GenerateGitCommitMessageAction.java b/src/main/java/ee/carlrobert/codegpt/actions/GenerateGitCommitMessageAction.java index 1bc1e26e..f65a7692 100644 --- a/src/main/java/ee/carlrobert/codegpt/actions/GenerateGitCommitMessageAction.java +++ b/src/main/java/ee/carlrobert/codegpt/actions/GenerateGitCommitMessageAction.java @@ -2,6 +2,12 @@ package ee.carlrobert.codegpt.actions; import static com.intellij.openapi.ui.Messages.OK; import static com.intellij.util.ObjectUtils.tryCast; +import static ee.carlrobert.codegpt.settings.service.ServiceType.ANTHROPIC; +import static ee.carlrobert.codegpt.settings.service.ServiceType.AZURE; +import static ee.carlrobert.codegpt.settings.service.ServiceType.CUSTOM_OPENAI; +import static ee.carlrobert.codegpt.settings.service.ServiceType.LLAMA_CPP; +import static ee.carlrobert.codegpt.settings.service.ServiceType.OPENAI; +import static ee.carlrobert.codegpt.settings.service.ServiceType.YOU; import static java.util.stream.Collectors.joining; import static java.util.stream.Collectors.toList; @@ -16,10 +22,13 @@ import com.intellij.openapi.editor.Document; import com.intellij.openapi.editor.Editor; import com.intellij.openapi.editor.ex.EditorEx; import com.intellij.openapi.project.Project; +import com.intellij.openapi.vcs.FilePath; import com.intellij.openapi.vcs.VcsDataKeys; +import com.intellij.openapi.vcs.changes.Change; import com.intellij.openapi.vcs.changes.ui.ChangesBrowserBase; import com.intellij.openapi.vcs.changes.ui.CommitDialogChangesBrowser; import com.intellij.openapi.vcs.ui.CommitMessage; +import com.intellij.openapi.vfs.VirtualFile; import ee.carlrobert.codegpt.CodeGPTBundle; import ee.carlrobert.codegpt.EncodingManager; import ee.carlrobert.codegpt.Icons; @@ -35,8 +44,12 @@ import java.io.BufferedReader; import java.io.File; import java.io.IOException; import java.io.InputStreamReader; +import java.util.AbstractMap; import java.util.ArrayList; import java.util.List; +import java.util.Objects; +import java.util.function.Function; +import java.util.stream.Stream; import okhttp3.sse.EventSource; import org.jetbrains.annotations.NotNull; @@ -56,23 +69,26 @@ public class GenerateGitCommitMessageAction extends AnAction { @Override public void update(@NotNull AnActionEvent event) { var selectedService = GeneralSettings.getCurrentState().getSelectedService(); - if (selectedService == ServiceType.OPENAI || selectedService == ServiceType.AZURE - || selectedService == ServiceType.LLAMA_CPP) { - var filesSelected = !getReferencedFilePaths(event).isEmpty(); - var callAllowed = (selectedService == ServiceType.OPENAI - && OpenAICredentialManager.getInstance().isCredentialSet()) - || (selectedService == ServiceType.AZURE - && AzureCredentialsManager.getInstance().isCredentialSet()) - || selectedService == ServiceType.LLAMA_CPP; - event.getPresentation().setEnabled(callAllowed && filesSelected); - event.getPresentation().setText(CodeGPTBundle.get(callAllowed - ? "action.generateCommitMessage.title" - : "action.generateCommitMessage.missingCredentials")); - } else { - event.getPresentation().setEnabled(false); - event.getPresentation() - .setText(CodeGPTBundle.get("action.generateCommitMessage.serviceWarning")); + if (selectedService == YOU) { + event.getPresentation().setVisible(false); + return; } + + var includedChangesFilePaths = getIncludedChangesFilePaths(event); + var includedUnversionedChangesFilePaths = getIncludedUnversionedFilePaths(event); + var filesSelected = + !includedChangesFilePaths.isEmpty() || !includedUnversionedChangesFilePaths.isEmpty(); + var callAllowed = isCallAllowed(selectedService); + event.getPresentation().setEnabled(callAllowed && filesSelected); + event.getPresentation().setText(CodeGPTBundle.get(callAllowed + ? "action.generateCommitMessage.title" + : "action.generateCommitMessage.missingCredentials")); + } + + private boolean isCallAllowed(ServiceType serviceType) { + return (serviceType == OPENAI && OpenAICredentialManager.getInstance().isCredentialSet()) + || (serviceType == AZURE && AzureCredentialsManager.getInstance().isCredentialSet()) + || List.of(LLAMA_CPP, ANTHROPIC, CUSTOM_OPENAI).contains(serviceType); } @Override @@ -82,7 +98,10 @@ public class GenerateGitCommitMessageAction extends AnAction { return; } - var gitDiff = getGitDiff(project, getReferencedFilePaths(event)); + var gitDiff = getGitDiff( + project, + getIncludedChangesFilePaths(event), + getIncludedUnversionedFilePaths(event)); var tokenCount = encodingManager.countTokens(gitDiff); if (tokenCount > MAX_TOKEN_COUNT_WARNING @@ -130,16 +149,31 @@ public class GenerateGitCommitMessageAction extends AnAction { return commitMessage != null ? commitMessage.getEditorField().getEditor() : null; } - private String getGitDiff(Project project, List filePaths) { - var process = createGitDiffProcess(project.getBasePath(), filePaths); - var reader = new BufferedReader(new InputStreamReader(process.getInputStream())); - return reader.lines().collect(joining("\n")); + private String getGitDiff( + Project project, + List includedChangesFilePaths, + List includedUnversionedFilePaths) { + return Stream.of( + new AbstractMap.SimpleEntry<>(includedChangesFilePaths, true), + new AbstractMap.SimpleEntry<>(includedUnversionedFilePaths, false)) + .filter(entry -> !entry.getKey().isEmpty()) + .map(entry -> { + var process = + createGitDiffProcess(project.getBasePath(), entry.getKey(), entry.getValue()); + return new BufferedReader(new InputStreamReader(process.getInputStream())) + .lines() + .collect(joining("\n")); + }) + .collect(joining("\n")); } - private Process createGitDiffProcess(String projectPath, List filePaths) { + private Process createGitDiffProcess(String projectPath, List filePaths, boolean cached) { var command = new ArrayList(); command.add("git"); command.add("diff"); + if (cached) { + command.add("--cached"); + } command.addAll(filePaths); var processBuilder = new ProcessBuilder(command); @@ -151,16 +185,29 @@ public class GenerateGitCommitMessageAction extends AnAction { } } - private @NotNull List getReferencedFilePaths(AnActionEvent event) { + private @NotNull List getFilePaths( + AnActionEvent event, + Function> extractor) { var changesBrowserBase = event.getData(ChangesBrowserBase.DATA_KEY); if (changesBrowserBase == null) { return List.of(); } - var includedChanges = ((CommitDialogChangesBrowser) changesBrowserBase).getIncludedChanges(); - return includedChanges.stream() - .filter(item -> item.getVirtualFile() != null) - .map(item -> item.getVirtualFile().getPath()) + return extractor.apply((CommitDialogChangesBrowser) changesBrowserBase) + .map(obj -> obj instanceof Change + ? ((Change) obj).getVirtualFile() + : ((FilePath) obj).getVirtualFile()) + .filter(Objects::nonNull) + .map(VirtualFile::getPath) + .distinct() .collect(toList()); } + + private @NotNull List getIncludedChangesFilePaths(AnActionEvent event) { + return getFilePaths(event, browser -> browser.getIncludedChanges().stream()); + } + + private @NotNull List getIncludedUnversionedFilePaths(AnActionEvent event) { + return getFilePaths(event, browser -> browser.getIncludedUnversionedFiles().stream()); + } } diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java index dc14abe6..b2e4a826 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java @@ -104,6 +104,15 @@ public class CompletionRequestProvider { .build(); } + public static Request buildCustomOpenAICompletionRequest(String system, String context) { + return buildCustomOpenAIChatCompletionRequest( + CustomServiceSettings.getCurrentState(), + List.of( + new OpenAIChatCompletionMessage("system", system), + new OpenAIChatCompletionMessage("user", context)), + true); + } + public static Request buildCustomOpenAILookupCompletionRequest(String context) { return buildCustomOpenAIChatCompletionRequest( CustomServiceSettings.getCurrentState(), diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java index 02964cb1..cdd42790 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java @@ -8,6 +8,7 @@ import static ee.carlrobert.codegpt.settings.service.ServiceType.YOU; import com.intellij.openapi.application.ApplicationManager; import com.intellij.openapi.components.Service; +import com.intellij.openapi.diagnostic.Logger; import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestProvider; import ee.carlrobert.codegpt.codecompletions.InfillRequestDetails; import ee.carlrobert.codegpt.completions.llama.LlamaModel; @@ -16,11 +17,14 @@ import ee.carlrobert.codegpt.credentials.AzureCredentialsManager; import ee.carlrobert.codegpt.credentials.OpenAICredentialManager; import ee.carlrobert.codegpt.settings.GeneralSettings; import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings; +import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings; import ee.carlrobert.codegpt.settings.service.azure.AzureSettings; import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings; import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings; import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings; import ee.carlrobert.llm.client.DeserializationUtil; +import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionRequest; +import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionRequestMessage; import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest; import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionEventSourceListener; import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionMessage; @@ -37,6 +41,8 @@ import okhttp3.sse.EventSources; @Service public final class CompletionRequestService { + private static final Logger LOG = Logger.getInstance(CompletionRequestService.class); + private CompletionRequestService() { } @@ -122,43 +128,68 @@ public final class CompletionRequestService { public void generateCommitMessageAsync( String prompt, CompletionEventListener eventListener) { - var request = new OpenAIChatCompletionRequest.Builder(List.of( - new OpenAIChatCompletionMessage("system", - ConfigurationSettings.getCurrentState().getCommitMessagePrompt()), + var configuration = ConfigurationSettings.getCurrentState(); + var commitMessagePrompt = configuration.getCommitMessagePrompt(); + var openaiRequest = new OpenAIChatCompletionRequest.Builder(List.of( + new OpenAIChatCompletionMessage("system", commitMessagePrompt), new OpenAIChatCompletionMessage("user", prompt))) .setModel(OpenAISettings.getCurrentState().getModel()) .build(); var selectedService = GeneralSettings.getCurrentState().getSelectedService(); - if (selectedService == OPENAI) { - CompletionClientProvider.getOpenAIClient().getChatCompletionAsync(request, eventListener); - } - if (selectedService == AZURE) { - CompletionClientProvider.getAzureClient().getChatCompletionAsync(request, eventListener); - } - - if (selectedService == LLAMA_CPP) { - var settings = LlamaSettings.getCurrentState(); - PromptTemplate promptTemplate; - if (settings.isRunLocalServer()) { - promptTemplate = settings.isUseCustomModel() - ? settings.getLocalModelPromptTemplate() - : LlamaModel.findByHuggingFaceModel(settings.getHuggingFaceModel()).getPromptTemplate(); - } else { - promptTemplate = settings.getRemoteModelPromptTemplate(); - } - var finalPrompt = promptTemplate.buildPrompt( - ConfigurationSettings.getCurrentState().getCommitMessagePrompt(), - prompt, List.of()); - var configuration = ConfigurationSettings.getCurrentState(); - CompletionClientProvider.getLlamaClient().getChatCompletionAsync( - new LlamaCompletionRequest.Builder(finalPrompt) - .setN_predict(configuration.getMaxTokens()) - .setTemperature(configuration.getTemperature()) - .setTop_k(settings.getTopK()) - .setTop_p(settings.getTopP()) - .setMin_p(settings.getMinP()) - .setRepeat_penalty(settings.getRepeatPenalty()) - .build(), eventListener); + switch (selectedService) { + case OPENAI: + CompletionClientProvider.getOpenAIClient() + .getChatCompletionAsync(openaiRequest, eventListener); + break; + case CUSTOM_OPENAI: + var httpClient = CompletionClientProvider.getDefaultClientBuilder().build(); + EventSources.createFactory(httpClient).newEventSource( + CompletionRequestProvider.buildCustomOpenAICompletionRequest( + commitMessagePrompt, + prompt), + new OpenAIChatCompletionEventSourceListener(eventListener)); + break; + case ANTHROPIC: + var anthropicSettings = AnthropicSettings.getCurrentState(); + var claudeRequest = new ClaudeCompletionRequest(); + claudeRequest.setSystem(commitMessagePrompt); + claudeRequest.setStream(true); + claudeRequest.setMaxTokens(configuration.getMaxTokens()); + claudeRequest.setModel(anthropicSettings.getModel()); + claudeRequest.setMessages( + List.of(new ClaudeCompletionRequestMessage("user", prompt))); + CompletionClientProvider.getClaudeClient() + .getCompletionAsync(claudeRequest, eventListener); + break; + case AZURE: + CompletionClientProvider.getAzureClient() + .getChatCompletionAsync(openaiRequest, eventListener); + break; + case LLAMA_CPP: + var settings = LlamaSettings.getCurrentState(); + PromptTemplate promptTemplate; + if (settings.isRunLocalServer()) { + promptTemplate = settings.isUseCustomModel() + ? settings.getLocalModelPromptTemplate() + : LlamaModel.findByHuggingFaceModel(settings.getHuggingFaceModel()) + .getPromptTemplate(); + } else { + promptTemplate = settings.getRemoteModelPromptTemplate(); + } + var finalPrompt = promptTemplate.buildPrompt(commitMessagePrompt, prompt, List.of()); + CompletionClientProvider.getLlamaClient().getChatCompletionAsync( + new LlamaCompletionRequest.Builder(finalPrompt) + .setN_predict(configuration.getMaxTokens()) + .setTemperature(configuration.getTemperature()) + .setTop_k(settings.getTopK()) + .setTop_p(settings.getTopP()) + .setMin_p(settings.getMinP()) + .setRepeat_penalty(settings.getRepeatPenalty()) + .build(), eventListener); + break; + default: + LOG.debug("Unknown service: {}", selectedService); + break; } } diff --git a/src/main/resources/messages/codegpt.properties b/src/main/resources/messages/codegpt.properties index 9f71715c..dc6c02ca 100644 --- a/src/main/resources/messages/codegpt.properties +++ b/src/main/resources/messages/codegpt.properties @@ -1,7 +1,7 @@ project.label=CodeGPT notification.group.name=CodeGPT notification group action.generateCommitMessage.title=Generate Message -action.generateCommitMessage.description=Generate commit message using OpenAI service +action.generateCommitMessage.description=Generate commit message action.generateCommitMessage.serviceWarning=Messages can only be generated with OpenAI or Azure service action.generateCommitMessage.missingCredentials=Credentials not provided action.includeFilesInContext.title=Include In Context...