diff --git a/src/main/java/ee/carlrobert/codegpt/actions/GenerateGitCommitMessageAction.java b/src/main/java/ee/carlrobert/codegpt/actions/GenerateGitCommitMessageAction.java index 498ad797..e6c01889 100644 --- a/src/main/java/ee/carlrobert/codegpt/actions/GenerateGitCommitMessageAction.java +++ b/src/main/java/ee/carlrobert/codegpt/actions/GenerateGitCommitMessageAction.java @@ -41,6 +41,7 @@ import org.jetbrains.annotations.NotNull; public class GenerateGitCommitMessageAction extends AnAction { + public static final int MAX_TOKEN_COUNT_WARNING = 4096; private final EncodingManager encodingManager; public GenerateGitCommitMessageAction() { @@ -54,12 +55,14 @@ public class GenerateGitCommitMessageAction extends AnAction { @Override public void update(@NotNull AnActionEvent event) { var selectedService = GeneralSettings.getCurrentState().getSelectedService(); - if (selectedService == ServiceType.OPENAI || selectedService == ServiceType.AZURE) { + if (selectedService == ServiceType.OPENAI || selectedService == ServiceType.AZURE + || selectedService == ServiceType.LLAMA_CPP) { var filesSelected = !getReferencedFilePaths(event).isEmpty(); var callAllowed = (selectedService == ServiceType.OPENAI && OpenAICredentialManager.getInstance().isCredentialSet()) || (selectedService == ServiceType.AZURE - && AzureCredentialsManager.getInstance().isCredentialSet()); + && AzureCredentialsManager.getInstance().isCredentialSet()) + || selectedService == ServiceType.LLAMA_CPP; event.getPresentation().setEnabled(callAllowed && filesSelected); event.getPresentation().setText(CodeGPTBundle.get(callAllowed ? "action.generateCommitMessage.title" @@ -79,8 +82,10 @@ public class GenerateGitCommitMessageAction extends AnAction { } var gitDiff = getGitDiff(project, getReferencedFilePaths(event)); + var tokenCount = encodingManager.countTokens(gitDiff); - if (tokenCount > 4096 && OverlayUtil.showTokenSoftLimitWarningDialog(tokenCount) != OK) { + if (tokenCount > MAX_TOKEN_COUNT_WARNING + && OverlayUtil.showTokenSoftLimitWarningDialog(tokenCount) != OK) { return; } diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java index 3644a49a..7a832491 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java @@ -9,12 +9,16 @@ import com.intellij.openapi.application.ApplicationManager; import com.intellij.openapi.components.Service; import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestProvider; import ee.carlrobert.codegpt.codecompletions.InfillRequestDetails; +import ee.carlrobert.codegpt.completions.llama.LlamaModel; +import ee.carlrobert.codegpt.completions.llama.PromptTemplate; import ee.carlrobert.codegpt.credentials.AzureCredentialsManager; import ee.carlrobert.codegpt.credentials.OpenAICredentialManager; import ee.carlrobert.codegpt.settings.GeneralSettings; import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings; import ee.carlrobert.codegpt.settings.service.azure.AzureSettings; +import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings; import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings; +import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest; import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionMessage; import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest; import ee.carlrobert.llm.completion.CompletionEventListener; @@ -104,6 +108,31 @@ public final class CompletionRequestService { if (selectedService == AZURE) { CompletionClientProvider.getAzureClient().getChatCompletionAsync(request, eventListener); } + + if (selectedService == LLAMA_CPP) { + var settings = LlamaSettings.getCurrentState(); + PromptTemplate promptTemplate; + if (settings.isRunLocalServer()) { + promptTemplate = settings.isUseCustomModel() + ? settings.getLocalModelPromptTemplate() + : LlamaModel.findByHuggingFaceModel(settings.getHuggingFaceModel()).getPromptTemplate(); + } else { + promptTemplate = settings.getRemoteModelPromptTemplate(); + } + var finalPrompt = promptTemplate.buildPrompt( + ConfigurationSettings.getCurrentState().getCommitMessagePrompt(), + prompt, List.of()); + var configuration = ConfigurationSettings.getCurrentState(); + CompletionClientProvider.getLlamaClient().getChatCompletionAsync( + new LlamaCompletionRequest.Builder(finalPrompt) + .setN_predict(configuration.getMaxTokens()) + .setTemperature(configuration.getTemperature()) + .setTop_k(settings.getTopK()) + .setTop_p(settings.getTopP()) + .setMin_p(settings.getMinP()) + .setRepeat_penalty(settings.getRepeatPenalty()) + .build(), eventListener); + } } public Optional getLookupCompletion(String prompt) {