mirror of
https://github.com/carlrobertoh/ProxyAI.git
synced 2026-05-11 04:50:31 +00:00
feat: use llama cpp for generation of git commit message. (#380)
* Enable remote llama cpp server for Windows. * Mixtral instruct template was added. * Use llama cpp for generation of git commit message. * style fix
This commit is contained in:
parent
6e1a116ed2
commit
9627bbda15
2 changed files with 37 additions and 3 deletions
|
|
@ -41,6 +41,7 @@ import org.jetbrains.annotations.NotNull;
|
|||
|
||||
public class GenerateGitCommitMessageAction extends AnAction {
|
||||
|
||||
public static final int MAX_TOKEN_COUNT_WARNING = 4096;
|
||||
private final EncodingManager encodingManager;
|
||||
|
||||
public GenerateGitCommitMessageAction() {
|
||||
|
|
@ -54,12 +55,14 @@ public class GenerateGitCommitMessageAction extends AnAction {
|
|||
@Override
|
||||
public void update(@NotNull AnActionEvent event) {
|
||||
var selectedService = GeneralSettings.getCurrentState().getSelectedService();
|
||||
if (selectedService == ServiceType.OPENAI || selectedService == ServiceType.AZURE) {
|
||||
if (selectedService == ServiceType.OPENAI || selectedService == ServiceType.AZURE
|
||||
|| selectedService == ServiceType.LLAMA_CPP) {
|
||||
var filesSelected = !getReferencedFilePaths(event).isEmpty();
|
||||
var callAllowed = (selectedService == ServiceType.OPENAI
|
||||
&& OpenAICredentialManager.getInstance().isCredentialSet())
|
||||
|| (selectedService == ServiceType.AZURE
|
||||
&& AzureCredentialsManager.getInstance().isCredentialSet());
|
||||
&& AzureCredentialsManager.getInstance().isCredentialSet())
|
||||
|| selectedService == ServiceType.LLAMA_CPP;
|
||||
event.getPresentation().setEnabled(callAllowed && filesSelected);
|
||||
event.getPresentation().setText(CodeGPTBundle.get(callAllowed
|
||||
? "action.generateCommitMessage.title"
|
||||
|
|
@ -79,8 +82,10 @@ public class GenerateGitCommitMessageAction extends AnAction {
|
|||
}
|
||||
|
||||
var gitDiff = getGitDiff(project, getReferencedFilePaths(event));
|
||||
|
||||
var tokenCount = encodingManager.countTokens(gitDiff);
|
||||
if (tokenCount > 4096 && OverlayUtil.showTokenSoftLimitWarningDialog(tokenCount) != OK) {
|
||||
if (tokenCount > MAX_TOKEN_COUNT_WARNING
|
||||
&& OverlayUtil.showTokenSoftLimitWarningDialog(tokenCount) != OK) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -9,12 +9,16 @@ import com.intellij.openapi.application.ApplicationManager;
|
|||
import com.intellij.openapi.components.Service;
|
||||
import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestProvider;
|
||||
import ee.carlrobert.codegpt.codecompletions.InfillRequestDetails;
|
||||
import ee.carlrobert.codegpt.completions.llama.LlamaModel;
|
||||
import ee.carlrobert.codegpt.completions.llama.PromptTemplate;
|
||||
import ee.carlrobert.codegpt.credentials.AzureCredentialsManager;
|
||||
import ee.carlrobert.codegpt.credentials.OpenAICredentialManager;
|
||||
import ee.carlrobert.codegpt.settings.GeneralSettings;
|
||||
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings;
|
||||
import ee.carlrobert.codegpt.settings.service.azure.AzureSettings;
|
||||
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
|
||||
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
|
||||
import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest;
|
||||
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionMessage;
|
||||
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest;
|
||||
import ee.carlrobert.llm.completion.CompletionEventListener;
|
||||
|
|
@ -104,6 +108,31 @@ public final class CompletionRequestService {
|
|||
if (selectedService == AZURE) {
|
||||
CompletionClientProvider.getAzureClient().getChatCompletionAsync(request, eventListener);
|
||||
}
|
||||
|
||||
if (selectedService == LLAMA_CPP) {
|
||||
var settings = LlamaSettings.getCurrentState();
|
||||
PromptTemplate promptTemplate;
|
||||
if (settings.isRunLocalServer()) {
|
||||
promptTemplate = settings.isUseCustomModel()
|
||||
? settings.getLocalModelPromptTemplate()
|
||||
: LlamaModel.findByHuggingFaceModel(settings.getHuggingFaceModel()).getPromptTemplate();
|
||||
} else {
|
||||
promptTemplate = settings.getRemoteModelPromptTemplate();
|
||||
}
|
||||
var finalPrompt = promptTemplate.buildPrompt(
|
||||
ConfigurationSettings.getCurrentState().getCommitMessagePrompt(),
|
||||
prompt, List.of());
|
||||
var configuration = ConfigurationSettings.getCurrentState();
|
||||
CompletionClientProvider.getLlamaClient().getChatCompletionAsync(
|
||||
new LlamaCompletionRequest.Builder(finalPrompt)
|
||||
.setN_predict(configuration.getMaxTokens())
|
||||
.setTemperature(configuration.getTemperature())
|
||||
.setTop_k(settings.getTopK())
|
||||
.setTop_p(settings.getTopP())
|
||||
.setMin_p(settings.getMinP())
|
||||
.setRepeat_penalty(settings.getRepeatPenalty())
|
||||
.build(), eventListener);
|
||||
}
|
||||
}
|
||||
|
||||
public Optional<String> getLookupCompletion(String prompt) {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue