mirror of
https://github.com/carlrobertoh/ProxyAI.git
synced 2026-05-12 22:31:24 +00:00
feat: support git commit message generation with custom openai and anthropic service (#390)
This commit is contained in:
parent
9990c6a57b
commit
8c986fd7de
4 changed files with 148 additions and 61 deletions
|
|
@ -2,6 +2,12 @@ package ee.carlrobert.codegpt.actions;
|
|||
|
||||
import static com.intellij.openapi.ui.Messages.OK;
|
||||
import static com.intellij.util.ObjectUtils.tryCast;
|
||||
import static ee.carlrobert.codegpt.settings.service.ServiceType.ANTHROPIC;
|
||||
import static ee.carlrobert.codegpt.settings.service.ServiceType.AZURE;
|
||||
import static ee.carlrobert.codegpt.settings.service.ServiceType.CUSTOM_OPENAI;
|
||||
import static ee.carlrobert.codegpt.settings.service.ServiceType.LLAMA_CPP;
|
||||
import static ee.carlrobert.codegpt.settings.service.ServiceType.OPENAI;
|
||||
import static ee.carlrobert.codegpt.settings.service.ServiceType.YOU;
|
||||
import static java.util.stream.Collectors.joining;
|
||||
import static java.util.stream.Collectors.toList;
|
||||
|
||||
|
|
@ -16,10 +22,13 @@ import com.intellij.openapi.editor.Document;
|
|||
import com.intellij.openapi.editor.Editor;
|
||||
import com.intellij.openapi.editor.ex.EditorEx;
|
||||
import com.intellij.openapi.project.Project;
|
||||
import com.intellij.openapi.vcs.FilePath;
|
||||
import com.intellij.openapi.vcs.VcsDataKeys;
|
||||
import com.intellij.openapi.vcs.changes.Change;
|
||||
import com.intellij.openapi.vcs.changes.ui.ChangesBrowserBase;
|
||||
import com.intellij.openapi.vcs.changes.ui.CommitDialogChangesBrowser;
|
||||
import com.intellij.openapi.vcs.ui.CommitMessage;
|
||||
import com.intellij.openapi.vfs.VirtualFile;
|
||||
import ee.carlrobert.codegpt.CodeGPTBundle;
|
||||
import ee.carlrobert.codegpt.EncodingManager;
|
||||
import ee.carlrobert.codegpt.Icons;
|
||||
|
|
@ -35,8 +44,12 @@ import java.io.BufferedReader;
|
|||
import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.util.AbstractMap;
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Stream;
|
||||
import okhttp3.sse.EventSource;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
|
||||
|
|
@ -56,23 +69,26 @@ public class GenerateGitCommitMessageAction extends AnAction {
|
|||
@Override
|
||||
public void update(@NotNull AnActionEvent event) {
|
||||
var selectedService = GeneralSettings.getCurrentState().getSelectedService();
|
||||
if (selectedService == ServiceType.OPENAI || selectedService == ServiceType.AZURE
|
||||
|| selectedService == ServiceType.LLAMA_CPP) {
|
||||
var filesSelected = !getReferencedFilePaths(event).isEmpty();
|
||||
var callAllowed = (selectedService == ServiceType.OPENAI
|
||||
&& OpenAICredentialManager.getInstance().isCredentialSet())
|
||||
|| (selectedService == ServiceType.AZURE
|
||||
&& AzureCredentialsManager.getInstance().isCredentialSet())
|
||||
|| selectedService == ServiceType.LLAMA_CPP;
|
||||
event.getPresentation().setEnabled(callAllowed && filesSelected);
|
||||
event.getPresentation().setText(CodeGPTBundle.get(callAllowed
|
||||
? "action.generateCommitMessage.title"
|
||||
: "action.generateCommitMessage.missingCredentials"));
|
||||
} else {
|
||||
event.getPresentation().setEnabled(false);
|
||||
event.getPresentation()
|
||||
.setText(CodeGPTBundle.get("action.generateCommitMessage.serviceWarning"));
|
||||
if (selectedService == YOU) {
|
||||
event.getPresentation().setVisible(false);
|
||||
return;
|
||||
}
|
||||
|
||||
var includedChangesFilePaths = getIncludedChangesFilePaths(event);
|
||||
var includedUnversionedChangesFilePaths = getIncludedUnversionedFilePaths(event);
|
||||
var filesSelected =
|
||||
!includedChangesFilePaths.isEmpty() || !includedUnversionedChangesFilePaths.isEmpty();
|
||||
var callAllowed = isCallAllowed(selectedService);
|
||||
event.getPresentation().setEnabled(callAllowed && filesSelected);
|
||||
event.getPresentation().setText(CodeGPTBundle.get(callAllowed
|
||||
? "action.generateCommitMessage.title"
|
||||
: "action.generateCommitMessage.missingCredentials"));
|
||||
}
|
||||
|
||||
private boolean isCallAllowed(ServiceType serviceType) {
|
||||
return (serviceType == OPENAI && OpenAICredentialManager.getInstance().isCredentialSet())
|
||||
|| (serviceType == AZURE && AzureCredentialsManager.getInstance().isCredentialSet())
|
||||
|| List.of(LLAMA_CPP, ANTHROPIC, CUSTOM_OPENAI).contains(serviceType);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
@ -82,7 +98,10 @@ public class GenerateGitCommitMessageAction extends AnAction {
|
|||
return;
|
||||
}
|
||||
|
||||
var gitDiff = getGitDiff(project, getReferencedFilePaths(event));
|
||||
var gitDiff = getGitDiff(
|
||||
project,
|
||||
getIncludedChangesFilePaths(event),
|
||||
getIncludedUnversionedFilePaths(event));
|
||||
|
||||
var tokenCount = encodingManager.countTokens(gitDiff);
|
||||
if (tokenCount > MAX_TOKEN_COUNT_WARNING
|
||||
|
|
@ -130,16 +149,31 @@ public class GenerateGitCommitMessageAction extends AnAction {
|
|||
return commitMessage != null ? commitMessage.getEditorField().getEditor() : null;
|
||||
}
|
||||
|
||||
private String getGitDiff(Project project, List<String> filePaths) {
|
||||
var process = createGitDiffProcess(project.getBasePath(), filePaths);
|
||||
var reader = new BufferedReader(new InputStreamReader(process.getInputStream()));
|
||||
return reader.lines().collect(joining("\n"));
|
||||
private String getGitDiff(
|
||||
Project project,
|
||||
List<String> includedChangesFilePaths,
|
||||
List<String> includedUnversionedFilePaths) {
|
||||
return Stream.of(
|
||||
new AbstractMap.SimpleEntry<>(includedChangesFilePaths, true),
|
||||
new AbstractMap.SimpleEntry<>(includedUnversionedFilePaths, false))
|
||||
.filter(entry -> !entry.getKey().isEmpty())
|
||||
.map(entry -> {
|
||||
var process =
|
||||
createGitDiffProcess(project.getBasePath(), entry.getKey(), entry.getValue());
|
||||
return new BufferedReader(new InputStreamReader(process.getInputStream()))
|
||||
.lines()
|
||||
.collect(joining("\n"));
|
||||
})
|
||||
.collect(joining("\n"));
|
||||
}
|
||||
|
||||
private Process createGitDiffProcess(String projectPath, List<String> filePaths) {
|
||||
private Process createGitDiffProcess(String projectPath, List<String> filePaths, boolean cached) {
|
||||
var command = new ArrayList<String>();
|
||||
command.add("git");
|
||||
command.add("diff");
|
||||
if (cached) {
|
||||
command.add("--cached");
|
||||
}
|
||||
command.addAll(filePaths);
|
||||
|
||||
var processBuilder = new ProcessBuilder(command);
|
||||
|
|
@ -151,16 +185,29 @@ public class GenerateGitCommitMessageAction extends AnAction {
|
|||
}
|
||||
}
|
||||
|
||||
private @NotNull List<String> getReferencedFilePaths(AnActionEvent event) {
|
||||
private @NotNull List<String> getFilePaths(
|
||||
AnActionEvent event,
|
||||
Function<CommitDialogChangesBrowser, Stream<?>> extractor) {
|
||||
var changesBrowserBase = event.getData(ChangesBrowserBase.DATA_KEY);
|
||||
if (changesBrowserBase == null) {
|
||||
return List.of();
|
||||
}
|
||||
|
||||
var includedChanges = ((CommitDialogChangesBrowser) changesBrowserBase).getIncludedChanges();
|
||||
return includedChanges.stream()
|
||||
.filter(item -> item.getVirtualFile() != null)
|
||||
.map(item -> item.getVirtualFile().getPath())
|
||||
return extractor.apply((CommitDialogChangesBrowser) changesBrowserBase)
|
||||
.map(obj -> obj instanceof Change
|
||||
? ((Change) obj).getVirtualFile()
|
||||
: ((FilePath) obj).getVirtualFile())
|
||||
.filter(Objects::nonNull)
|
||||
.map(VirtualFile::getPath)
|
||||
.distinct()
|
||||
.collect(toList());
|
||||
}
|
||||
|
||||
private @NotNull List<String> getIncludedChangesFilePaths(AnActionEvent event) {
|
||||
return getFilePaths(event, browser -> browser.getIncludedChanges().stream());
|
||||
}
|
||||
|
||||
private @NotNull List<String> getIncludedUnversionedFilePaths(AnActionEvent event) {
|
||||
return getFilePaths(event, browser -> browser.getIncludedUnversionedFiles().stream());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -104,6 +104,15 @@ public class CompletionRequestProvider {
|
|||
.build();
|
||||
}
|
||||
|
||||
public static Request buildCustomOpenAICompletionRequest(String system, String context) {
|
||||
return buildCustomOpenAIChatCompletionRequest(
|
||||
CustomServiceSettings.getCurrentState(),
|
||||
List.of(
|
||||
new OpenAIChatCompletionMessage("system", system),
|
||||
new OpenAIChatCompletionMessage("user", context)),
|
||||
true);
|
||||
}
|
||||
|
||||
public static Request buildCustomOpenAILookupCompletionRequest(String context) {
|
||||
return buildCustomOpenAIChatCompletionRequest(
|
||||
CustomServiceSettings.getCurrentState(),
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import static ee.carlrobert.codegpt.settings.service.ServiceType.YOU;
|
|||
|
||||
import com.intellij.openapi.application.ApplicationManager;
|
||||
import com.intellij.openapi.components.Service;
|
||||
import com.intellij.openapi.diagnostic.Logger;
|
||||
import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestProvider;
|
||||
import ee.carlrobert.codegpt.codecompletions.InfillRequestDetails;
|
||||
import ee.carlrobert.codegpt.completions.llama.LlamaModel;
|
||||
|
|
@ -16,11 +17,14 @@ import ee.carlrobert.codegpt.credentials.AzureCredentialsManager;
|
|||
import ee.carlrobert.codegpt.credentials.OpenAICredentialManager;
|
||||
import ee.carlrobert.codegpt.settings.GeneralSettings;
|
||||
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings;
|
||||
import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings;
|
||||
import ee.carlrobert.codegpt.settings.service.azure.AzureSettings;
|
||||
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings;
|
||||
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
|
||||
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
|
||||
import ee.carlrobert.llm.client.DeserializationUtil;
|
||||
import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionRequest;
|
||||
import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionRequestMessage;
|
||||
import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest;
|
||||
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionEventSourceListener;
|
||||
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionMessage;
|
||||
|
|
@ -37,6 +41,8 @@ import okhttp3.sse.EventSources;
|
|||
@Service
|
||||
public final class CompletionRequestService {
|
||||
|
||||
private static final Logger LOG = Logger.getInstance(CompletionRequestService.class);
|
||||
|
||||
private CompletionRequestService() {
|
||||
}
|
||||
|
||||
|
|
@ -122,43 +128,68 @@ public final class CompletionRequestService {
|
|||
public void generateCommitMessageAsync(
|
||||
String prompt,
|
||||
CompletionEventListener<String> eventListener) {
|
||||
var request = new OpenAIChatCompletionRequest.Builder(List.of(
|
||||
new OpenAIChatCompletionMessage("system",
|
||||
ConfigurationSettings.getCurrentState().getCommitMessagePrompt()),
|
||||
var configuration = ConfigurationSettings.getCurrentState();
|
||||
var commitMessagePrompt = configuration.getCommitMessagePrompt();
|
||||
var openaiRequest = new OpenAIChatCompletionRequest.Builder(List.of(
|
||||
new OpenAIChatCompletionMessage("system", commitMessagePrompt),
|
||||
new OpenAIChatCompletionMessage("user", prompt)))
|
||||
.setModel(OpenAISettings.getCurrentState().getModel())
|
||||
.build();
|
||||
var selectedService = GeneralSettings.getCurrentState().getSelectedService();
|
||||
if (selectedService == OPENAI) {
|
||||
CompletionClientProvider.getOpenAIClient().getChatCompletionAsync(request, eventListener);
|
||||
}
|
||||
if (selectedService == AZURE) {
|
||||
CompletionClientProvider.getAzureClient().getChatCompletionAsync(request, eventListener);
|
||||
}
|
||||
|
||||
if (selectedService == LLAMA_CPP) {
|
||||
var settings = LlamaSettings.getCurrentState();
|
||||
PromptTemplate promptTemplate;
|
||||
if (settings.isRunLocalServer()) {
|
||||
promptTemplate = settings.isUseCustomModel()
|
||||
? settings.getLocalModelPromptTemplate()
|
||||
: LlamaModel.findByHuggingFaceModel(settings.getHuggingFaceModel()).getPromptTemplate();
|
||||
} else {
|
||||
promptTemplate = settings.getRemoteModelPromptTemplate();
|
||||
}
|
||||
var finalPrompt = promptTemplate.buildPrompt(
|
||||
ConfigurationSettings.getCurrentState().getCommitMessagePrompt(),
|
||||
prompt, List.of());
|
||||
var configuration = ConfigurationSettings.getCurrentState();
|
||||
CompletionClientProvider.getLlamaClient().getChatCompletionAsync(
|
||||
new LlamaCompletionRequest.Builder(finalPrompt)
|
||||
.setN_predict(configuration.getMaxTokens())
|
||||
.setTemperature(configuration.getTemperature())
|
||||
.setTop_k(settings.getTopK())
|
||||
.setTop_p(settings.getTopP())
|
||||
.setMin_p(settings.getMinP())
|
||||
.setRepeat_penalty(settings.getRepeatPenalty())
|
||||
.build(), eventListener);
|
||||
switch (selectedService) {
|
||||
case OPENAI:
|
||||
CompletionClientProvider.getOpenAIClient()
|
||||
.getChatCompletionAsync(openaiRequest, eventListener);
|
||||
break;
|
||||
case CUSTOM_OPENAI:
|
||||
var httpClient = CompletionClientProvider.getDefaultClientBuilder().build();
|
||||
EventSources.createFactory(httpClient).newEventSource(
|
||||
CompletionRequestProvider.buildCustomOpenAICompletionRequest(
|
||||
commitMessagePrompt,
|
||||
prompt),
|
||||
new OpenAIChatCompletionEventSourceListener(eventListener));
|
||||
break;
|
||||
case ANTHROPIC:
|
||||
var anthropicSettings = AnthropicSettings.getCurrentState();
|
||||
var claudeRequest = new ClaudeCompletionRequest();
|
||||
claudeRequest.setSystem(commitMessagePrompt);
|
||||
claudeRequest.setStream(true);
|
||||
claudeRequest.setMaxTokens(configuration.getMaxTokens());
|
||||
claudeRequest.setModel(anthropicSettings.getModel());
|
||||
claudeRequest.setMessages(
|
||||
List.of(new ClaudeCompletionRequestMessage("user", prompt)));
|
||||
CompletionClientProvider.getClaudeClient()
|
||||
.getCompletionAsync(claudeRequest, eventListener);
|
||||
break;
|
||||
case AZURE:
|
||||
CompletionClientProvider.getAzureClient()
|
||||
.getChatCompletionAsync(openaiRequest, eventListener);
|
||||
break;
|
||||
case LLAMA_CPP:
|
||||
var settings = LlamaSettings.getCurrentState();
|
||||
PromptTemplate promptTemplate;
|
||||
if (settings.isRunLocalServer()) {
|
||||
promptTemplate = settings.isUseCustomModel()
|
||||
? settings.getLocalModelPromptTemplate()
|
||||
: LlamaModel.findByHuggingFaceModel(settings.getHuggingFaceModel())
|
||||
.getPromptTemplate();
|
||||
} else {
|
||||
promptTemplate = settings.getRemoteModelPromptTemplate();
|
||||
}
|
||||
var finalPrompt = promptTemplate.buildPrompt(commitMessagePrompt, prompt, List.of());
|
||||
CompletionClientProvider.getLlamaClient().getChatCompletionAsync(
|
||||
new LlamaCompletionRequest.Builder(finalPrompt)
|
||||
.setN_predict(configuration.getMaxTokens())
|
||||
.setTemperature(configuration.getTemperature())
|
||||
.setTop_k(settings.getTopK())
|
||||
.setTop_p(settings.getTopP())
|
||||
.setMin_p(settings.getMinP())
|
||||
.setRepeat_penalty(settings.getRepeatPenalty())
|
||||
.build(), eventListener);
|
||||
break;
|
||||
default:
|
||||
LOG.debug("Unknown service: {}", selectedService);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue