From eddde39eff4f5e55250d7fc9fb4fceaa5534ae5f Mon Sep 17 00:00:00 2001 From: Carl-Robert Linnupuu Date: Sun, 22 Sep 2024 02:07:01 +0300 Subject: [PATCH] feat: enable Edit Code feature for all providers (closes #700, #698, #696) --- .../CompletionRequestProvider.java | 377 +++++++++++------- .../completions/CompletionRequestService.java | 72 +++- .../codegpt/completions/ConversationType.java | 1 + .../codegpt/completions/llama/LlamaModel.java | 11 +- .../editor/EditCodeSubmissionHandler.kt | 14 +- .../carlrobert/codegpt/ui/EditCodePopover.kt | 4 +- .../CompletionRequestProviderTest.kt | 17 +- .../testsupport/mixin/ShortcutsTestMixin.kt | 1 + 8 files changed, 308 insertions(+), 189 deletions(-) diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java index 9886b420..cdb33519 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java @@ -4,6 +4,7 @@ import static ee.carlrobert.codegpt.completions.ConversationType.FIX_COMPILE_ERR import static ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.CUSTOM_SERVICE_API_KEY; import static ee.carlrobert.codegpt.util.file.FileUtil.getResourceContent; import static java.lang.String.format; +import static java.util.Collections.emptyList; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.joining; import static java.util.stream.Collectors.toList; @@ -85,13 +86,6 @@ public class CompletionRequestProvider { public static final String EDIT_CODE_SYSTEM_PROMPT = getResourceContent("/prompts/edit-code.txt"); - private final EncodingManager encodingManager = EncodingManager.getInstance(); - private final Conversation conversation; - - public CompletionRequestProvider(Conversation conversation) { - this.conversation = conversation; - } - public static String getPromptWithContext(List referencedFiles, String userPrompt) { var includedFilesSettings = IncludedFilesSettings.getCurrentState(); @@ -127,7 +121,7 @@ public class CompletionRequestProvider { public static OpenAIChatCompletionRequest buildEditCodeRequest( String context, - String model) { + @Nullable String model) { return new OpenAIChatCompletionRequest.Builder( List.of( new OpenAIChatCompletionStandardMessage("system", EDIT_CODE_SYSTEM_PROMPT), @@ -138,131 +132,13 @@ public class CompletionRequestProvider { .build(); } - public static Request buildCustomOpenAICompletionRequest(String system, String context) { + public static Request buildCustomOpenAIChatCompletionRequest(CallParameters callParameters) { return buildCustomOpenAIChatCompletionRequest( ApplicationManager.getApplication().getService(CustomServiceSettings.class) .getState() .getChatCompletionSettings(), - List.of( - new OpenAIChatCompletionStandardMessage("system", system), - new OpenAIChatCompletionStandardMessage("user", context)), - true); - } - - public static Request buildCustomOpenAICompletionRequest(String context, String url, - Map headers, Map body, String credential) { - var usedSettings = new CustomServiceChatCompletionSettingsState(); - usedSettings.setBody(body); - usedSettings.setHeaders(headers); - usedSettings.setUrl(url); - return buildCustomOpenAIChatCompletionRequest( - usedSettings, - List.of(new OpenAIChatCompletionStandardMessage("user", context)), + CompletionRequestProvider.buildOpenAIMessages(callParameters), true, - credential); - } - - public static Request buildCustomOpenAILookupCompletionRequest(String context) { - return buildCustomOpenAIChatCompletionRequest( - ApplicationManager.getApplication().getService(CustomServiceSettings.class) - .getState() - .getChatCompletionSettings(), - List.of( - new OpenAIChatCompletionStandardMessage( - "system", - GENERATE_COMMIT_MESSAGE_SYSTEM_PROMPT), - new OpenAIChatCompletionStandardMessage("user", context)), - false); - } - - public static LlamaCompletionRequest buildLlamaLookupCompletionRequest(String context) { - return new LlamaCompletionRequest.Builder( - PromptTemplate.LLAMA.buildPrompt(GENERATE_COMMIT_MESSAGE_SYSTEM_PROMPT, context, List.of())) - .setStream(false) - .build(); - } - - public LlamaCompletionRequest buildLlamaCompletionRequest( - Message message, - ConversationType conversationType) { - var settings = LlamaSettings.getCurrentState(); - PromptTemplate promptTemplate; - if (settings.isRunLocalServer()) { - promptTemplate = settings.isUseCustomModel() - ? settings.getLocalModelPromptTemplate() - : LlamaModel.findByHuggingFaceModel(settings.getHuggingFaceModel()).getPromptTemplate(); - } else { - promptTemplate = settings.getRemoteModelPromptTemplate(); - } - - var systemPrompt = conversationType == FIX_COMPILE_ERRORS - ? FIX_COMPILE_ERRORS_SYSTEM_PROMPT : PersonaSettings.getSystemPrompt(); - - var prompt = promptTemplate.buildPrompt( - systemPrompt, - message.getPrompt(), - conversation.getMessages()); - var configuration = ConfigurationSettings.getState(); - return new LlamaCompletionRequest.Builder(prompt) - .setN_predict(configuration.getMaxTokens()) - .setTemperature(configuration.getTemperature()) - .setTop_k(settings.getTopK()) - .setTop_p(settings.getTopP()) - .setMin_p(settings.getMinP()) - .setRepeat_penalty(settings.getRepeatPenalty()) - .setStop(promptTemplate.getStopTokens()) - .build(); - } - - public OpenAIChatCompletionRequest buildOpenAIChatCompletionRequest( - @Nullable String model, - CallParameters callParameters) { - var configuration = ConfigurationSettings.getState(); - var requestBuilder = new OpenAIChatCompletionRequest.Builder( - buildOpenAIMessages(model, callParameters)) - .setModel(model) - .setMaxTokens(configuration.getMaxTokens()) - .setStream(true) - .setTemperature(configuration.getTemperature()); - if (callParameters.getMessage().isWebSearchIncluded()) { - // tri-state boolean - requestBuilder.setWebSearchIncluded(true); - } - var documentationDetails = - callParameters.getMessage().getDocumentationDetails(); - if (documentationDetails != null) { - var requestDocumentationDetails = new RequestDocumentationDetails(); - requestDocumentationDetails.setName(documentationDetails.getName()); - requestDocumentationDetails.setUrl(documentationDetails.getUrl()); - requestBuilder.setDocumentationDetails(requestDocumentationDetails); - } - return requestBuilder.build(); - } - - public GoogleCompletionRequest buildGoogleChatCompletionRequest( - @Nullable String model, - CallParameters callParameters) { - var configuration = ConfigurationSettings.getState(); - return new GoogleCompletionRequest.Builder(buildGoogleMessages(model, callParameters)) - .generationConfig(new GoogleGenerationConfig.Builder() - .maxOutputTokens(configuration.getMaxTokens()) - .temperature(configuration.getTemperature()).build()).build(); - } - - public Request buildCustomOpenAIChatCompletionRequest( - CustomServiceChatCompletionSettingsState settings, - CallParameters callParameters) { - return buildCustomOpenAIChatCompletionRequest( - settings, - buildOpenAIMessages(callParameters), - true); - } - - private static Request buildCustomOpenAIChatCompletionRequest( - CustomServiceChatCompletionSettingsState settings, - List messages, - boolean streamRequest) { - return buildCustomOpenAIChatCompletionRequest(settings, messages, streamRequest, CredentialsStore.getCredential(CUSTOM_SERVICE_API_KEY)); } @@ -271,7 +147,8 @@ public class CompletionRequestProvider { List messages, boolean streamRequest, String credential) { - var requestBuilder = new Request.Builder().url(requireNonNull(settings.getUrl()).trim()); + var requestBuilder = + new Request.Builder().url(requireNonNull(settings.getUrl()).trim()); for (var entry : settings.getHeaders().entrySet()) { String value = entry.getValue(); if (credential != null && value.contains("$CUSTOM_SERVICE_API_KEY")) { @@ -307,7 +184,168 @@ public class CompletionRequestProvider { } } - public ClaudeCompletionRequest buildAnthropicChatCompletionRequest( + public static Request buildCustomOpenAIEditCodeRequest(String input) { + return buildCustomOpenAIChatCompletionRequest( + ApplicationManager.getApplication().getService(CustomServiceSettings.class) + .getState() + .getChatCompletionSettings(), + List.of( + new OpenAIChatCompletionStandardMessage("system", EDIT_CODE_SYSTEM_PROMPT), + new OpenAIChatCompletionStandardMessage("user", input)), + true, + CredentialsStore.getCredential(CUSTOM_SERVICE_API_KEY)); + } + + public static Request buildCustomOpenAICompletionRequest(String system, String context) { + return buildCustomOpenAIChatCompletionRequest( + ApplicationManager.getApplication().getService(CustomServiceSettings.class) + .getState() + .getChatCompletionSettings(), + List.of( + new OpenAIChatCompletionStandardMessage("system", system), + new OpenAIChatCompletionStandardMessage("user", context)), + true, + CredentialsStore.getCredential(CUSTOM_SERVICE_API_KEY)); + } + + public static Request buildCustomOpenAICompletionRequest(String context, String url, + Map headers, Map body, String credential) { + var usedSettings = new CustomServiceChatCompletionSettingsState(); + usedSettings.setBody(body); + usedSettings.setHeaders(headers); + usedSettings.setUrl(url); + return buildCustomOpenAIChatCompletionRequest( + usedSettings, + List.of(new OpenAIChatCompletionStandardMessage("user", context)), + true, + credential); + } + + public static Request buildCustomOpenAILookupCompletionRequest(String context) { + return buildCustomOpenAIChatCompletionRequest( + ApplicationManager.getApplication().getService(CustomServiceSettings.class) + .getState() + .getChatCompletionSettings(), + List.of( + new OpenAIChatCompletionStandardMessage( + "system", + GENERATE_COMMIT_MESSAGE_SYSTEM_PROMPT), + new OpenAIChatCompletionStandardMessage("user", context)), + false, + CredentialsStore.getCredential(CUSTOM_SERVICE_API_KEY)); + } + + public static LlamaCompletionRequest buildLlamaLookupCompletionRequest(String context) { + return new LlamaCompletionRequest.Builder( + PromptTemplate.LLAMA.buildPrompt(GENERATE_COMMIT_MESSAGE_SYSTEM_PROMPT, context, List.of())) + .setStream(false) + .build(); + } + + public static LlamaCompletionRequest buildLlamaCompletionRequest( + Message message, + Conversation conversation, + ConversationType conversationType) { + var settings = LlamaSettings.getCurrentState(); + PromptTemplate promptTemplate; + if (settings.isRunLocalServer()) { + promptTemplate = settings.isUseCustomModel() + ? settings.getLocalModelPromptTemplate() + : LlamaModel.findByHuggingFaceModel(settings.getHuggingFaceModel()).getPromptTemplate(); + } else { + promptTemplate = settings.getRemoteModelPromptTemplate(); + } + + var systemPrompt = conversationType == FIX_COMPILE_ERRORS + ? FIX_COMPILE_ERRORS_SYSTEM_PROMPT : PersonaSettings.getSystemPrompt(); + + var prompt = promptTemplate.buildPrompt( + systemPrompt, + message.getPrompt(), + conversation.getMessages()); + var configuration = ConfigurationSettings.getState(); + return new LlamaCompletionRequest.Builder(prompt) + .setN_predict(configuration.getMaxTokens()) + .setTemperature(configuration.getTemperature()) + .setTop_k(settings.getTopK()) + .setTop_p(settings.getTopP()) + .setMin_p(settings.getMinP()) + .setRepeat_penalty(settings.getRepeatPenalty()) + .setStop(promptTemplate.getStopTokens()) + .build(); + } + + public static LlamaCompletionRequest buildLlamaEditCodeRequest(String input) { + var settings = LlamaSettings.getCurrentState(); + PromptTemplate promptTemplate; + if (settings.isRunLocalServer()) { + promptTemplate = settings.isUseCustomModel() + ? settings.getLocalModelPromptTemplate() + : LlamaModel.findByHuggingFaceModel(settings.getHuggingFaceModel()).getPromptTemplate(); + } else { + promptTemplate = settings.getRemoteModelPromptTemplate(); + } + + var prompt = promptTemplate.buildPrompt(EDIT_CODE_SYSTEM_PROMPT, input, emptyList()); + var configuration = ConfigurationSettings.getState(); + return new LlamaCompletionRequest.Builder(prompt) + .setN_predict(configuration.getMaxTokens()) + .setTemperature(configuration.getTemperature()) + .setTop_k(settings.getTopK()) + .setTop_p(settings.getTopP()) + .setMin_p(settings.getMinP()) + .setRepeat_penalty(settings.getRepeatPenalty()) + .setStop(promptTemplate.getStopTokens()) + .build(); + } + + public static OpenAIChatCompletionRequest buildOpenAIChatCompletionRequest( + @Nullable String model, + CallParameters callParameters) { + var configuration = ConfigurationSettings.getState(); + var requestBuilder = new OpenAIChatCompletionRequest.Builder( + buildOpenAIMessages(model, callParameters)) + .setModel(model) + .setMaxTokens(configuration.getMaxTokens()) + .setStream(true) + .setTemperature(configuration.getTemperature()); + if (callParameters.getMessage().isWebSearchIncluded()) { + // tri-state boolean + requestBuilder.setWebSearchIncluded(true); + } + var documentationDetails = + callParameters.getMessage().getDocumentationDetails(); + if (documentationDetails != null) { + var requestDocumentationDetails = new RequestDocumentationDetails(); + requestDocumentationDetails.setName(documentationDetails.getName()); + requestDocumentationDetails.setUrl(documentationDetails.getUrl()); + requestBuilder.setDocumentationDetails(requestDocumentationDetails); + } + return requestBuilder.build(); + } + + public static GoogleCompletionRequest buildGoogleChatCompletionRequest( + @Nullable String model, + CallParameters callParameters) { + var configuration = ConfigurationSettings.getState(); + return new GoogleCompletionRequest.Builder(buildGoogleMessages(model, callParameters)) + .generationConfig(new GoogleGenerationConfig.Builder() + .maxOutputTokens(configuration.getMaxTokens()) + .temperature(configuration.getTemperature()).build()).build(); + } + + public static GoogleCompletionRequest buildGoogleEditCodeRequest(String input) { + var configuration = ConfigurationSettings.getState(); + return new GoogleCompletionRequest.Builder(List.of( + new GoogleCompletionContent("user", List.of(EDIT_CODE_SYSTEM_PROMPT)), + new GoogleCompletionContent("model", List.of("Understood.")), + new GoogleCompletionContent("user", List.of(input)))) + .generationConfig(new GoogleGenerationConfig.Builder() + .maxOutputTokens(configuration.getMaxTokens()) + .temperature(configuration.getTemperature()).build()).build(); + } + + public static ClaudeCompletionRequest buildAnthropicChatCompletionRequest( CallParameters callParameters) { var configuration = ConfigurationSettings.getState(); var settings = AnthropicSettings.getCurrentState(); @@ -316,7 +354,7 @@ public class CompletionRequestProvider { request.setMaxTokens(configuration.getMaxTokens()); request.setStream(true); request.setSystem(PersonaSettings.getSystemPrompt()); - List messages = conversation.getMessages().stream() + List messages = callParameters.getConversation().getMessages().stream() .filter(prevMessage -> prevMessage.getResponse() != null && !prevMessage.getResponse().isEmpty()) .flatMap(prevMessage -> Stream.of( @@ -339,7 +377,20 @@ public class CompletionRequestProvider { return request; } - public OllamaChatCompletionRequest buildOllamaChatCompletionRequest( + public static ClaudeCompletionRequest buildAnthropicEditCodeRequest( + String input) { + var configuration = ConfigurationSettings.getState(); + var settings = AnthropicSettings.getCurrentState(); + var request = new ClaudeCompletionRequest(); + request.setModel(settings.getModel()); + request.setMaxTokens(configuration.getMaxTokens()); + request.setStream(true); + request.setSystem(EDIT_CODE_SYSTEM_PROMPT); + request.setMessages(List.of(new ClaudeCompletionStandardMessage("user", input))); + return request; + } + + public static OllamaChatCompletionRequest buildOllamaChatCompletionRequest( CallParameters callParameters ) { var configuration = ConfigurationSettings.getState(); @@ -354,7 +405,24 @@ public class CompletionRequestProvider { .build(); } - private List buildOllamaMessages(CallParameters callParameters) { + public static OllamaChatCompletionRequest buildOllamaEditCodeRequest(String input) { + var configuration = ConfigurationSettings.getState(); + var settings = + ApplicationManager.getApplication().getService(OllamaSettings.class).getState(); + return new OllamaChatCompletionRequest + .Builder(settings.getModel(), List.of( + new OllamaChatCompletionMessage("system", EDIT_CODE_SYSTEM_PROMPT, null), + new OllamaChatCompletionMessage("user", input, null))) + .setStream(true) + .setOptions(new OllamaParameters.Builder() + .numPredict(configuration.getMaxTokens()) + .temperature((double) configuration.getTemperature()) + .build()) + .build(); + } + + private static List buildOllamaMessages( + CallParameters callParameters) { var message = callParameters.getMessage(); var messages = new ArrayList(); if (callParameters.getConversationType() == ConversationType.DEFAULT) { @@ -367,7 +435,7 @@ public class CompletionRequestProvider { ); } - for (var prevMessage : conversation.getMessages()) { + for (var prevMessage : callParameters.getConversation().getMessages()) { if (callParameters.isRetry() && prevMessage.getId().equals(message.getId())) { break; } @@ -406,7 +474,8 @@ public class CompletionRequestProvider { return messages; } - private List buildOpenAIMessages(CallParameters callParameters) { + private static List buildOpenAIMessages( + CallParameters callParameters) { var message = callParameters.getMessage(); var messages = new ArrayList(); if (callParameters.getConversationType() == ConversationType.DEFAULT) { @@ -425,7 +494,7 @@ public class CompletionRequestProvider { new OpenAIChatCompletionStandardMessage("system", FIX_COMPILE_ERRORS_SYSTEM_PROMPT)); } - for (var prevMessage : conversation.getMessages()) { + for (var prevMessage : callParameters.getConversation().getMessages()) { if (callParameters.isRetry() && prevMessage.getId().equals(message.getId())) { break; } @@ -463,7 +532,7 @@ public class CompletionRequestProvider { return messages; } - private List buildOpenAIMessages( + public static List buildOpenAIMessages( @Nullable String model, CallParameters callParameters) { var messages = buildOpenAIMessages(callParameters); @@ -472,6 +541,7 @@ public class CompletionRequestProvider { return messages; } + var encodingManager = EncodingManager.getInstance(); int totalUsage = messages.parallelStream() .mapToInt(encodingManager::countMessageTokens) .sum() + ConfigurationSettings.getState().getMaxTokens(); @@ -485,10 +555,14 @@ public class CompletionRequestProvider { } catch (NoSuchElementException ex) { return messages; } - return tryReducingMessagesOrThrow(messages, totalUsage, modelMaxTokens); + return tryReducingMessagesOrThrow( + messages, + callParameters.getConversation().isDiscardTokenLimit(), + totalUsage, + modelMaxTokens); } - private List buildGoogleMessages(CallParameters callParameters) { + private static List buildGoogleMessages(CallParameters callParameters) { var message = callParameters.getMessage(); var messages = new ArrayList(); // Gemini API does not support direct 'system' prompts: @@ -503,7 +577,7 @@ public class CompletionRequestProvider { messages.add(new GoogleCompletionContent("model", List.of("Understood."))); } - for (var prevMessage : conversation.getMessages()) { + for (var prevMessage : callParameters.getConversation().getMessages()) { if (callParameters.isRetry() && prevMessage.getId().equals(message.getId())) { break; } @@ -538,7 +612,7 @@ public class CompletionRequestProvider { return messages; } - private List buildGoogleMessages( + private static List buildGoogleMessages( @Nullable String model, CallParameters callParameters) { var messages = buildGoogleMessages(callParameters); @@ -547,6 +621,7 @@ public class CompletionRequestProvider { return messages; } + var encodingManager = EncodingManager.getInstance(); int totalUsage = messages.parallelStream() .mapToInt(message -> encodingManager.countMessageTokens(message.getRole(), String.join(",", message.getParts().stream().map(GoogleContentPart::getText).toList()))) @@ -561,19 +636,24 @@ public class CompletionRequestProvider { } catch (NoSuchElementException ex) { return messages; } - return tryReducingGoogleMessagesOrThrow(messages, totalUsage, modelMaxTokens); + return tryReducingGoogleMessagesOrThrow( + messages, + callParameters.getConversation().isDiscardTokenLimit(), + totalUsage, + modelMaxTokens); } - private List tryReducingMessagesOrThrow( + private static List tryReducingMessagesOrThrow( List messages, + boolean discardTokenLimit, int totalUsage, int modelMaxTokens) { if (!ConversationsState.getInstance().discardAllTokenLimits) { - if (!conversation.isDiscardTokenLimit()) { + if (!discardTokenLimit) { throw new TotalUsageExceededException(); } } - + var encodingManager = EncodingManager.getInstance(); // skip the system prompt for (int i = 1; i < messages.size(); i++) { if (totalUsage <= modelMaxTokens) { @@ -590,16 +670,17 @@ public class CompletionRequestProvider { return messages.stream().filter(Objects::nonNull).toList(); } - private List tryReducingGoogleMessagesOrThrow( + private static List tryReducingGoogleMessagesOrThrow( List messages, + boolean discardTokenLimit, int totalUsage, int modelMaxTokens) { if (!ConversationsState.getInstance().discardAllTokenLimits) { - if (!conversation.isDiscardTokenLimit()) { + if (!discardTokenLimit) { throw new TotalUsageExceededException(); } } - + var encodingManager = EncodingManager.getInstance(); // skip the system prompt for (int i = 1; i < messages.size(); i++) { if (totalUsage <= modelMaxTokens) { diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java index 547de07c..4e49ce40 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestService.java @@ -6,6 +6,7 @@ import com.intellij.openapi.diagnostic.Logger; import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestFactory; import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestProvider; import ee.carlrobert.codegpt.codecompletions.InfillRequest; +import ee.carlrobert.codegpt.actions.editor.EditCodeRequestParams; import ee.carlrobert.codegpt.completions.llama.LlamaModel; import ee.carlrobert.codegpt.completions.llama.PromptTemplate; import ee.carlrobert.codegpt.credentials.CredentialsStore; @@ -16,7 +17,6 @@ import ee.carlrobert.codegpt.settings.service.ServiceType; import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings; import ee.carlrobert.codegpt.settings.service.azure.AzureSettings; import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTServiceSettings; -import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings; import ee.carlrobert.codegpt.settings.service.google.GoogleSettings; import ee.carlrobert.codegpt.settings.service.google.GoogleSettingsState; import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings; @@ -83,10 +83,9 @@ public final class CompletionRequestService { CallParameters callParameters, CompletionEventListener eventListener) { var application = ApplicationManager.getApplication(); - var requestProvider = new CompletionRequestProvider(callParameters.getConversation()); return switch (GeneralSettings.getSelectedService()) { case CODEGPT -> CompletionClientProvider.getCodeGPTClient().getChatCompletionAsync( - requestProvider.buildOpenAIChatCompletionRequest( + CompletionRequestProvider.buildOpenAIChatCompletionRequest( application.getService(CodeGPTServiceSettings.class) .getState() .getChatCompletionSettings() @@ -95,35 +94,32 @@ public final class CompletionRequestService { eventListener ); case OPENAI -> CompletionClientProvider.getOpenAIClient().getChatCompletionAsync( - requestProvider.buildOpenAIChatCompletionRequest( + CompletionRequestProvider.buildOpenAIChatCompletionRequest( OpenAISettings.getCurrentState().getModel(), callParameters), eventListener); case CUSTOM_OPENAI -> getCustomOpenAIChatCompletionAsync( - requestProvider.buildCustomOpenAIChatCompletionRequest( - application.getService(CustomServiceSettings.class) - .getState() - .getChatCompletionSettings(), - callParameters), + CompletionRequestProvider.buildCustomOpenAIChatCompletionRequest(callParameters), eventListener); case ANTHROPIC -> CompletionClientProvider.getClaudeClient().getCompletionAsync( - requestProvider.buildAnthropicChatCompletionRequest(callParameters), + CompletionRequestProvider.buildAnthropicChatCompletionRequest(callParameters), eventListener); case AZURE -> CompletionClientProvider.getAzureClient().getChatCompletionAsync( - requestProvider.buildOpenAIChatCompletionRequest(null, callParameters), + CompletionRequestProvider.buildOpenAIChatCompletionRequest(null, callParameters), eventListener); case LLAMA_CPP -> CompletionClientProvider.getLlamaClient().getChatCompletionAsync( - requestProvider.buildLlamaCompletionRequest( + CompletionRequestProvider.buildLlamaCompletionRequest( callParameters.getMessage(), + callParameters.getConversation(), callParameters.getConversationType()), eventListener); case OLLAMA -> CompletionClientProvider.getOllamaClient().getChatCompletionAsync( - requestProvider.buildOllamaChatCompletionRequest(callParameters), + CompletionRequestProvider.buildOllamaChatCompletionRequest(callParameters), eventListener); case GOOGLE -> { var settings = application.getService(GoogleSettings.class).getState(); yield CompletionClientProvider.getGoogleClient().getChatCompletionAsync( - requestProvider.buildGoogleChatCompletionRequest( + CompletionRequestProvider.buildGoogleChatCompletionRequest( settings.getModel(), callParameters), settings.getModel(), @@ -132,6 +128,54 @@ public final class CompletionRequestService { }; } + public EventSource getEditCodeCompletionAsync( + EditCodeRequestParams params, + CompletionEventListener eventListener) { + var input = "%s\n\n%s".formatted(params.getPrompt(), params.getSelectedText()); + return switch (GeneralSettings.getSelectedService()) { + case CODEGPT -> CompletionClientProvider.getCodeGPTClient().getChatCompletionAsync( + CompletionRequestProvider.buildEditCodeRequest( + input, + ApplicationManager.getApplication().getService(CodeGPTServiceSettings.class) + .getState() + .getChatCompletionSettings() + .getModel() + ), + eventListener + ); + case OPENAI -> CompletionClientProvider.getOpenAIClient().getChatCompletionAsync( + CompletionRequestProvider.buildEditCodeRequest( + input, + OpenAISettings.getCurrentState().getModel()), + eventListener); + case CUSTOM_OPENAI -> getCustomOpenAIChatCompletionAsync( + CompletionRequestProvider.buildCustomOpenAIEditCodeRequest(input), + eventListener); + case ANTHROPIC -> CompletionClientProvider.getClaudeClient().getCompletionAsync( + CompletionRequestProvider.buildAnthropicEditCodeRequest(input), + eventListener); + case AZURE -> CompletionClientProvider.getAzureClient().getChatCompletionAsync( + CompletionRequestProvider.buildEditCodeRequest(input, null), + eventListener); + case LLAMA_CPP -> CompletionClientProvider.getLlamaClient().getChatCompletionAsync( + CompletionRequestProvider.buildLlamaEditCodeRequest(input), + eventListener); + case OLLAMA -> CompletionClientProvider.getOllamaClient().getChatCompletionAsync( + CompletionRequestProvider.buildOllamaEditCodeRequest(input), + eventListener); + case GOOGLE -> { + var model = + ApplicationManager.getApplication().getService(GoogleSettings.class) + .getState() + .getModel(); + yield CompletionClientProvider.getGoogleClient().getChatCompletionAsync( + CompletionRequestProvider.buildGoogleEditCodeRequest(input), + model, + eventListener); + } + }; + } + public EventSource getCodeCompletionAsync( InfillRequest requestDetails, CompletionEventListener eventListener) { diff --git a/src/main/java/ee/carlrobert/codegpt/completions/ConversationType.java b/src/main/java/ee/carlrobert/codegpt/completions/ConversationType.java index efff0bbf..622c7fc2 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/ConversationType.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/ConversationType.java @@ -7,4 +7,5 @@ public enum ConversationType { FIX_COMPILE_ERRORS, MULTI_FILE, INLINE_COMPLETION, + EDIT_CODE } diff --git a/src/main/java/ee/carlrobert/codegpt/completions/llama/LlamaModel.java b/src/main/java/ee/carlrobert/codegpt/completions/llama/LlamaModel.java index e7e4da00..f7441ea2 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/llama/LlamaModel.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/llama/LlamaModel.java @@ -172,11 +172,16 @@ public enum LlamaModel { HuggingFaceModel.CODE_QWEN_1_5_7B_Q6_K)), CODE_QWEN2_5_CODER( "CodeQwen2.5 Coder", """ - Qwen2.5-Coder is the latest series of Code-Specific Qwen large language models (formerly known as CodeQwen).\ + Qwen2.5-Coder is the latest series of Code-Specific Qwen large language models \ + (formerly known as CodeQwen). It brings the following improvements upon CodeQwen1.5: - - Significantly improvements in code generation, code reasoning and code fixing. Base on the strong Qwen2.5, we scale up the training tokens into 5.5 trillion including source code, text-code grounding, Synthetic data, etc. - - A more comprehensive foundation for real-world applications such as Code Agents. Not only enhancing coding capabilities but also maintaining its strengths in mathematics and general competencies. + - Significantly improvements in code generation, code reasoning and code fixing. \ + Base on the strong Qwen2.5, we scale up the training tokens into 5.5 trillion including \ + source code, text-code grounding, Synthetic data, etc. + - A more comprehensive foundation for real-world applications such as Code Agents. \ + Not only enhancing coding capabilities but also maintaining its strengths in \ + mathematics and general competencies. - Long-context Support up to 128K tokens. """, PromptTemplate.CODE_QWEN, diff --git a/src/main/kotlin/ee/carlrobert/codegpt/actions/editor/EditCodeSubmissionHandler.kt b/src/main/kotlin/ee/carlrobert/codegpt/actions/editor/EditCodeSubmissionHandler.kt index 60215202..bef9f1b1 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/actions/editor/EditCodeSubmissionHandler.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/actions/editor/EditCodeSubmissionHandler.kt @@ -9,12 +9,12 @@ import com.intellij.openapi.util.TextRange import com.intellij.util.ui.AsyncProcessIcon import com.intellij.openapi.util.text.StringUtil import com.jetbrains.rd.util.AtomicReference -import ee.carlrobert.codegpt.completions.CompletionClientProvider -import ee.carlrobert.codegpt.completions.CompletionRequestProvider -import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTServiceSettings +import ee.carlrobert.codegpt.completions.CompletionRequestService import ee.carlrobert.codegpt.ui.ObservableProperties import javax.swing.JButton +data class EditCodeRequestParams(val prompt: String, val selectedText: String) + class EditCodeSubmissionHandler( private val editor: Editor, private val observableProperties: ObservableProperties, @@ -43,12 +43,8 @@ class EditCodeSubmissionHandler( } runInEdt { editor.selectionModel.removeSelection() } - // TODO: Support other providers - CompletionClientProvider.getCodeGPTClient().getChatCompletionAsync( - CompletionRequestProvider.buildEditCodeRequest( - "$userPrompt\n\n$selectedText", - service().state.chatCompletionSettings.model - ), + service().getEditCodeCompletionAsync( + EditCodeRequestParams(userPrompt, selectedText), EditCodeCompletionListener( editor, selectionTextRange, diff --git a/src/main/kotlin/ee/carlrobert/codegpt/ui/EditCodePopover.kt b/src/main/kotlin/ee/carlrobert/codegpt/ui/EditCodePopover.kt index d0a2e10e..3c40f478 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/ui/EditCodePopover.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/ui/EditCodePopover.kt @@ -21,7 +21,6 @@ import com.intellij.util.ui.JBUI import ee.carlrobert.codegpt.CodeGPTBundle import ee.carlrobert.codegpt.actions.editor.EditCodeSubmissionHandler import ee.carlrobert.codegpt.settings.GeneralSettings -import ee.carlrobert.codegpt.settings.service.ServiceType.CODEGPT import ee.carlrobert.codegpt.toolwindow.chat.ui.textarea.ModelComboBoxAction import ee.carlrobert.codegpt.util.ApplicationUtil import kotlinx.coroutines.CoroutineScope @@ -195,8 +194,7 @@ class EditCodePopover(private val editor: Editor) { ModelComboBoxAction( ApplicationUtil.findCurrentProject(), {}, - GeneralSettings.getSelectedService(), - listOf(CODEGPT) + GeneralSettings.getSelectedService() ) .createCustomComponent(ActionPlaces.UNKNOWN) ).align(AlignX.RIGHT) diff --git a/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt index ee330f4a..3c8f033d 100644 --- a/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt +++ b/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt @@ -4,10 +4,8 @@ import com.intellij.openapi.components.service import ee.carlrobert.codegpt.conversations.Conversation import ee.carlrobert.codegpt.conversations.ConversationService import ee.carlrobert.codegpt.conversations.message.Message -import ee.carlrobert.codegpt.settings.GeneralSettings import ee.carlrobert.codegpt.settings.persona.DEFAULT_PROMPT import ee.carlrobert.codegpt.settings.persona.PersonaSettings -import ee.carlrobert.codegpt.settings.service.ServiceType import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel import org.assertj.core.api.Assertions.assertThat import org.assertj.core.groups.Tuple @@ -24,8 +22,7 @@ class CompletionRequestProviderTest : IntegrationTest() { conversation.addMessage(firstMessage) conversation.addMessage(secondMessage) - val request = CompletionRequestProvider(conversation) - .buildOpenAIChatCompletionRequest( + val request = CompletionRequestProvider.buildOpenAIChatCompletionRequest( OpenAIChatCompletionModel.GPT_3_5.code, CallParameters( conversation, @@ -54,8 +51,7 @@ class CompletionRequestProviderTest : IntegrationTest() { conversation.addMessage(firstMessage) conversation.addMessage(secondMessage) - val request = CompletionRequestProvider(conversation) - .buildOpenAIChatCompletionRequest( + val request = CompletionRequestProvider.buildOpenAIChatCompletionRequest( OpenAIChatCompletionModel.GPT_3_5.code, CallParameters( conversation, @@ -84,8 +80,7 @@ class CompletionRequestProviderTest : IntegrationTest() { conversation.addMessage(firstMessage) conversation.addMessage(secondMessage) - val request = CompletionRequestProvider(conversation) - .buildOpenAIChatCompletionRequest( + val request = CompletionRequestProvider.buildOpenAIChatCompletionRequest( OpenAIChatCompletionModel.GPT_3_5.code, CallParameters( conversation, @@ -115,8 +110,7 @@ class CompletionRequestProviderTest : IntegrationTest() { conversation.addMessage(remainingMessage) conversation.discardTokenLimits() - val request = CompletionRequestProvider(conversation) - .buildOpenAIChatCompletionRequest( + val request = CompletionRequestProvider.buildOpenAIChatCompletionRequest( OpenAIChatCompletionModel.GPT_3_5.code, CallParameters( conversation, @@ -142,8 +136,7 @@ class CompletionRequestProviderTest : IntegrationTest() { conversation.addMessage(createDummyMessage(1500)) assertThrows(TotalUsageExceededException::class.java) { - CompletionRequestProvider(conversation) - .buildOpenAIChatCompletionRequest( + CompletionRequestProvider.buildOpenAIChatCompletionRequest( OpenAIChatCompletionModel.GPT_3_5.code, CallParameters( conversation, diff --git a/src/test/kotlin/testsupport/mixin/ShortcutsTestMixin.kt b/src/test/kotlin/testsupport/mixin/ShortcutsTestMixin.kt index abe12273..d82cfc6b 100644 --- a/src/test/kotlin/testsupport/mixin/ShortcutsTestMixin.kt +++ b/src/test/kotlin/testsupport/mixin/ShortcutsTestMixin.kt @@ -47,6 +47,7 @@ interface ShortcutsTestMixin { GeneralSettings.getCurrentState().selectedService = ServiceType.LLAMA_CPP LlamaSettings.getCurrentState().serverPort = null LlamaSettings.getCurrentState().isCodeCompletionsEnabled = codeCompletionsEnabled + LlamaSettings.getCurrentState().huggingFaceModel = HuggingFaceModel.CODE_LLAMA_7B_Q4 } fun useOllamaService() {