feat: use llama.cpp for git commit message generation (#380)

* Enable remote llama.cpp server for Windows.

* Add Mixtral instruct template.

* Use llama.cpp for git commit message generation.

* Style fix.
Oleksii Maryshchenko, 2024-02-22 11:23:22 +01:00, committed by GitHub
commit 9627bbda15 (parent 6e1a116ed2)
2 changed files with 37 additions and 3 deletions

GenerateGitCommitMessageAction.java

@@ -41,6 +41,7 @@ import org.jetbrains.annotations.NotNull;
 public class GenerateGitCommitMessageAction extends AnAction {
+  public static final int MAX_TOKEN_COUNT_WARNING = 4096;
   private final EncodingManager encodingManager;
   public GenerateGitCommitMessageAction() {
@@ -54,12 +55,14 @@ public class GenerateGitCommitMessageAction extends AnAction {
   @Override
   public void update(@NotNull AnActionEvent event) {
     var selectedService = GeneralSettings.getCurrentState().getSelectedService();
-    if (selectedService == ServiceType.OPENAI || selectedService == ServiceType.AZURE) {
+    if (selectedService == ServiceType.OPENAI || selectedService == ServiceType.AZURE
+        || selectedService == ServiceType.LLAMA_CPP) {
       var filesSelected = !getReferencedFilePaths(event).isEmpty();
       var callAllowed = (selectedService == ServiceType.OPENAI
           && OpenAICredentialManager.getInstance().isCredentialSet())
           || (selectedService == ServiceType.AZURE
-          && AzureCredentialsManager.getInstance().isCredentialSet());
+          && AzureCredentialsManager.getInstance().isCredentialSet())
+          || selectedService == ServiceType.LLAMA_CPP;
       event.getPresentation().setEnabled(callAllowed && filesSelected);
       event.getPresentation().setText(CodeGPTBundle.get(callAllowed
           ? "action.generateCommitMessage.title"
@@ -79,8 +82,10 @@ public class GenerateGitCommitMessageAction extends AnAction {
     }
     var gitDiff = getGitDiff(project, getReferencedFilePaths(event));
     var tokenCount = encodingManager.countTokens(gitDiff);
-    if (tokenCount > 4096 && OverlayUtil.showTokenSoftLimitWarningDialog(tokenCount) != OK) {
+    if (tokenCount > MAX_TOKEN_COUNT_WARNING
+        && OverlayUtil.showTokenSoftLimitWarningDialog(tokenCount) != OK) {
       return;
     }
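Aside (not part of the diff): a short sketch of the guard's control flow, with names taken from the hunks above. The assumption, inferred from the code's shape, is that showTokenSoftLimitWarningDialog returns OK when the user chooses to proceed past the soft limit.

    // The three possible outcomes of the guard:
    //   tokenCount <= MAX_TOKEN_COUNT_WARNING           -> proceed silently
    //   tokenCount >  limit, user confirms the dialog   -> proceed anyway
    //   tokenCount >  limit, user cancels the dialog    -> return, nothing is sent
    // MAX_TOKEN_COUNT_WARNING replaces the magic number 4096 used before this commit.
    if (tokenCount > MAX_TOKEN_COUNT_WARNING
        && OverlayUtil.showTokenSoftLimitWarningDialog(tokenCount) != OK) {
      return;
    }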

CompletionRequestService.java

@@ -9,12 +9,16 @@ import com.intellij.openapi.application.ApplicationManager;
 import com.intellij.openapi.components.Service;
 import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestProvider;
 import ee.carlrobert.codegpt.codecompletions.InfillRequestDetails;
+import ee.carlrobert.codegpt.completions.llama.LlamaModel;
+import ee.carlrobert.codegpt.completions.llama.PromptTemplate;
 import ee.carlrobert.codegpt.credentials.AzureCredentialsManager;
 import ee.carlrobert.codegpt.credentials.OpenAICredentialManager;
 import ee.carlrobert.codegpt.settings.GeneralSettings;
 import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings;
 import ee.carlrobert.codegpt.settings.service.azure.AzureSettings;
+import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
 import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
+import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest;
 import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionMessage;
 import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest;
 import ee.carlrobert.llm.completion.CompletionEventListener;
@@ -104,6 +108,31 @@ public final class CompletionRequestService {
     if (selectedService == AZURE) {
       CompletionClientProvider.getAzureClient().getChatCompletionAsync(request, eventListener);
     }
+    if (selectedService == LLAMA_CPP) {
+      var settings = LlamaSettings.getCurrentState();
+      PromptTemplate promptTemplate;
+      if (settings.isRunLocalServer()) {
+        promptTemplate = settings.isUseCustomModel()
+            ? settings.getLocalModelPromptTemplate()
+            : LlamaModel.findByHuggingFaceModel(settings.getHuggingFaceModel()).getPromptTemplate();
+      } else {
+        promptTemplate = settings.getRemoteModelPromptTemplate();
+      }
+      var finalPrompt = promptTemplate.buildPrompt(
+          ConfigurationSettings.getCurrentState().getCommitMessagePrompt(),
+          prompt, List.of());
+      var configuration = ConfigurationSettings.getCurrentState();
+      CompletionClientProvider.getLlamaClient().getChatCompletionAsync(
+          new LlamaCompletionRequest.Builder(finalPrompt)
+              .setN_predict(configuration.getMaxTokens())
+              .setTemperature(configuration.getTemperature())
+              .setTop_k(settings.getTopK())
+              .setTop_p(settings.getTopP())
+              .setMin_p(settings.getMinP())
+              .setRepeat_penalty(settings.getRepeatPenalty())
+              .build(), eventListener);
+    }
   }
   public Optional<String> getLookupCompletion(String prompt) {
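Aside (not part of the diff): the hunk selects a PromptTemplate, but the new Mixtral instruct template itself is not shown. A rough sketch of what its prompt assembly could look like, assuming Mixtral's published [INST] ... [/INST] instruct format; the commit's actual template may differ in detail.

    // Hypothetical Mixtral-style prompt builder. The chat history parameter is left
    // out because the call above passes List.of() for one-shot generation.
    static String buildMixtralPrompt(String systemPrompt, String userPrompt) {
      return "<s>[INST] " + systemPrompt + "\n" + userPrompt + " [/INST]";
    }

In the call above, the system prompt is the configurable commit-message instruction from ConfigurationSettings and the user prompt is the staged git diff.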