feat: enable Edit Code feature for all providers (closes #700, #698, #696)

This commit is contained in:
Carl-Robert Linnupuu 2024-09-22 02:07:01 +03:00
parent 417a13afe2
commit eddde39eff
8 changed files with 308 additions and 189 deletions

View file

@ -4,6 +4,7 @@ import static ee.carlrobert.codegpt.completions.ConversationType.FIX_COMPILE_ERR
import static ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.CUSTOM_SERVICE_API_KEY;
import static ee.carlrobert.codegpt.util.file.FileUtil.getResourceContent;
import static java.lang.String.format;
import static java.util.Collections.emptyList;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;
@ -85,13 +86,6 @@ public class CompletionRequestProvider {
public static final String EDIT_CODE_SYSTEM_PROMPT =
getResourceContent("/prompts/edit-code.txt");
private final EncodingManager encodingManager = EncodingManager.getInstance();
private final Conversation conversation;
public CompletionRequestProvider(Conversation conversation) {
this.conversation = conversation;
}
public static String getPromptWithContext(List<ReferencedFile> referencedFiles,
String userPrompt) {
var includedFilesSettings = IncludedFilesSettings.getCurrentState();
@ -127,7 +121,7 @@ public class CompletionRequestProvider {
public static OpenAIChatCompletionRequest buildEditCodeRequest(
String context,
String model) {
@Nullable String model) {
return new OpenAIChatCompletionRequest.Builder(
List.of(
new OpenAIChatCompletionStandardMessage("system", EDIT_CODE_SYSTEM_PROMPT),
@ -138,131 +132,13 @@ public class CompletionRequestProvider {
.build();
}
public static Request buildCustomOpenAICompletionRequest(String system, String context) {
public static Request buildCustomOpenAIChatCompletionRequest(CallParameters callParameters) {
return buildCustomOpenAIChatCompletionRequest(
ApplicationManager.getApplication().getService(CustomServiceSettings.class)
.getState()
.getChatCompletionSettings(),
List.of(
new OpenAIChatCompletionStandardMessage("system", system),
new OpenAIChatCompletionStandardMessage("user", context)),
true);
}
public static Request buildCustomOpenAICompletionRequest(String context, String url,
Map<String, String> headers, Map<String, Object> body, String credential) {
var usedSettings = new CustomServiceChatCompletionSettingsState();
usedSettings.setBody(body);
usedSettings.setHeaders(headers);
usedSettings.setUrl(url);
return buildCustomOpenAIChatCompletionRequest(
usedSettings,
List.of(new OpenAIChatCompletionStandardMessage("user", context)),
CompletionRequestProvider.buildOpenAIMessages(callParameters),
true,
credential);
}
public static Request buildCustomOpenAILookupCompletionRequest(String context) {
return buildCustomOpenAIChatCompletionRequest(
ApplicationManager.getApplication().getService(CustomServiceSettings.class)
.getState()
.getChatCompletionSettings(),
List.of(
new OpenAIChatCompletionStandardMessage(
"system",
GENERATE_COMMIT_MESSAGE_SYSTEM_PROMPT),
new OpenAIChatCompletionStandardMessage("user", context)),
false);
}
public static LlamaCompletionRequest buildLlamaLookupCompletionRequest(String context) {
return new LlamaCompletionRequest.Builder(
PromptTemplate.LLAMA.buildPrompt(GENERATE_COMMIT_MESSAGE_SYSTEM_PROMPT, context, List.of()))
.setStream(false)
.build();
}
public LlamaCompletionRequest buildLlamaCompletionRequest(
Message message,
ConversationType conversationType) {
var settings = LlamaSettings.getCurrentState();
PromptTemplate promptTemplate;
if (settings.isRunLocalServer()) {
promptTemplate = settings.isUseCustomModel()
? settings.getLocalModelPromptTemplate()
: LlamaModel.findByHuggingFaceModel(settings.getHuggingFaceModel()).getPromptTemplate();
} else {
promptTemplate = settings.getRemoteModelPromptTemplate();
}
var systemPrompt = conversationType == FIX_COMPILE_ERRORS
? FIX_COMPILE_ERRORS_SYSTEM_PROMPT : PersonaSettings.getSystemPrompt();
var prompt = promptTemplate.buildPrompt(
systemPrompt,
message.getPrompt(),
conversation.getMessages());
var configuration = ConfigurationSettings.getState();
return new LlamaCompletionRequest.Builder(prompt)
.setN_predict(configuration.getMaxTokens())
.setTemperature(configuration.getTemperature())
.setTop_k(settings.getTopK())
.setTop_p(settings.getTopP())
.setMin_p(settings.getMinP())
.setRepeat_penalty(settings.getRepeatPenalty())
.setStop(promptTemplate.getStopTokens())
.build();
}
public OpenAIChatCompletionRequest buildOpenAIChatCompletionRequest(
@Nullable String model,
CallParameters callParameters) {
var configuration = ConfigurationSettings.getState();
var requestBuilder = new OpenAIChatCompletionRequest.Builder(
buildOpenAIMessages(model, callParameters))
.setModel(model)
.setMaxTokens(configuration.getMaxTokens())
.setStream(true)
.setTemperature(configuration.getTemperature());
if (callParameters.getMessage().isWebSearchIncluded()) {
// tri-state boolean
requestBuilder.setWebSearchIncluded(true);
}
var documentationDetails =
callParameters.getMessage().getDocumentationDetails();
if (documentationDetails != null) {
var requestDocumentationDetails = new RequestDocumentationDetails();
requestDocumentationDetails.setName(documentationDetails.getName());
requestDocumentationDetails.setUrl(documentationDetails.getUrl());
requestBuilder.setDocumentationDetails(requestDocumentationDetails);
}
return requestBuilder.build();
}
public GoogleCompletionRequest buildGoogleChatCompletionRequest(
@Nullable String model,
CallParameters callParameters) {
var configuration = ConfigurationSettings.getState();
return new GoogleCompletionRequest.Builder(buildGoogleMessages(model, callParameters))
.generationConfig(new GoogleGenerationConfig.Builder()
.maxOutputTokens(configuration.getMaxTokens())
.temperature(configuration.getTemperature()).build()).build();
}
public Request buildCustomOpenAIChatCompletionRequest(
CustomServiceChatCompletionSettingsState settings,
CallParameters callParameters) {
return buildCustomOpenAIChatCompletionRequest(
settings,
buildOpenAIMessages(callParameters),
true);
}
private static Request buildCustomOpenAIChatCompletionRequest(
CustomServiceChatCompletionSettingsState settings,
List<OpenAIChatCompletionMessage> messages,
boolean streamRequest) {
return buildCustomOpenAIChatCompletionRequest(settings, messages, streamRequest,
CredentialsStore.getCredential(CUSTOM_SERVICE_API_KEY));
}
@ -271,7 +147,8 @@ public class CompletionRequestProvider {
List<OpenAIChatCompletionMessage> messages,
boolean streamRequest,
String credential) {
var requestBuilder = new Request.Builder().url(requireNonNull(settings.getUrl()).trim());
var requestBuilder =
new Request.Builder().url(requireNonNull(settings.getUrl()).trim());
for (var entry : settings.getHeaders().entrySet()) {
String value = entry.getValue();
if (credential != null && value.contains("$CUSTOM_SERVICE_API_KEY")) {
@ -307,7 +184,168 @@ public class CompletionRequestProvider {
}
}
public ClaudeCompletionRequest buildAnthropicChatCompletionRequest(
public static Request buildCustomOpenAIEditCodeRequest(String input) {
return buildCustomOpenAIChatCompletionRequest(
ApplicationManager.getApplication().getService(CustomServiceSettings.class)
.getState()
.getChatCompletionSettings(),
List.of(
new OpenAIChatCompletionStandardMessage("system", EDIT_CODE_SYSTEM_PROMPT),
new OpenAIChatCompletionStandardMessage("user", input)),
true,
CredentialsStore.getCredential(CUSTOM_SERVICE_API_KEY));
}
public static Request buildCustomOpenAICompletionRequest(String system, String context) {
return buildCustomOpenAIChatCompletionRequest(
ApplicationManager.getApplication().getService(CustomServiceSettings.class)
.getState()
.getChatCompletionSettings(),
List.of(
new OpenAIChatCompletionStandardMessage("system", system),
new OpenAIChatCompletionStandardMessage("user", context)),
true,
CredentialsStore.getCredential(CUSTOM_SERVICE_API_KEY));
}
public static Request buildCustomOpenAICompletionRequest(String context, String url,
Map<String, String> headers, Map<String, Object> body, String credential) {
var usedSettings = new CustomServiceChatCompletionSettingsState();
usedSettings.setBody(body);
usedSettings.setHeaders(headers);
usedSettings.setUrl(url);
return buildCustomOpenAIChatCompletionRequest(
usedSettings,
List.of(new OpenAIChatCompletionStandardMessage("user", context)),
true,
credential);
}
public static Request buildCustomOpenAILookupCompletionRequest(String context) {
return buildCustomOpenAIChatCompletionRequest(
ApplicationManager.getApplication().getService(CustomServiceSettings.class)
.getState()
.getChatCompletionSettings(),
List.of(
new OpenAIChatCompletionStandardMessage(
"system",
GENERATE_COMMIT_MESSAGE_SYSTEM_PROMPT),
new OpenAIChatCompletionStandardMessage("user", context)),
false,
CredentialsStore.getCredential(CUSTOM_SERVICE_API_KEY));
}
public static LlamaCompletionRequest buildLlamaLookupCompletionRequest(String context) {
return new LlamaCompletionRequest.Builder(
PromptTemplate.LLAMA.buildPrompt(GENERATE_COMMIT_MESSAGE_SYSTEM_PROMPT, context, List.of()))
.setStream(false)
.build();
}
public static LlamaCompletionRequest buildLlamaCompletionRequest(
Message message,
Conversation conversation,
ConversationType conversationType) {
var settings = LlamaSettings.getCurrentState();
PromptTemplate promptTemplate;
if (settings.isRunLocalServer()) {
promptTemplate = settings.isUseCustomModel()
? settings.getLocalModelPromptTemplate()
: LlamaModel.findByHuggingFaceModel(settings.getHuggingFaceModel()).getPromptTemplate();
} else {
promptTemplate = settings.getRemoteModelPromptTemplate();
}
var systemPrompt = conversationType == FIX_COMPILE_ERRORS
? FIX_COMPILE_ERRORS_SYSTEM_PROMPT : PersonaSettings.getSystemPrompt();
var prompt = promptTemplate.buildPrompt(
systemPrompt,
message.getPrompt(),
conversation.getMessages());
var configuration = ConfigurationSettings.getState();
return new LlamaCompletionRequest.Builder(prompt)
.setN_predict(configuration.getMaxTokens())
.setTemperature(configuration.getTemperature())
.setTop_k(settings.getTopK())
.setTop_p(settings.getTopP())
.setMin_p(settings.getMinP())
.setRepeat_penalty(settings.getRepeatPenalty())
.setStop(promptTemplate.getStopTokens())
.build();
}
public static LlamaCompletionRequest buildLlamaEditCodeRequest(String input) {
var settings = LlamaSettings.getCurrentState();
PromptTemplate promptTemplate;
if (settings.isRunLocalServer()) {
promptTemplate = settings.isUseCustomModel()
? settings.getLocalModelPromptTemplate()
: LlamaModel.findByHuggingFaceModel(settings.getHuggingFaceModel()).getPromptTemplate();
} else {
promptTemplate = settings.getRemoteModelPromptTemplate();
}
var prompt = promptTemplate.buildPrompt(EDIT_CODE_SYSTEM_PROMPT, input, emptyList());
var configuration = ConfigurationSettings.getState();
return new LlamaCompletionRequest.Builder(prompt)
.setN_predict(configuration.getMaxTokens())
.setTemperature(configuration.getTemperature())
.setTop_k(settings.getTopK())
.setTop_p(settings.getTopP())
.setMin_p(settings.getMinP())
.setRepeat_penalty(settings.getRepeatPenalty())
.setStop(promptTemplate.getStopTokens())
.build();
}
public static OpenAIChatCompletionRequest buildOpenAIChatCompletionRequest(
@Nullable String model,
CallParameters callParameters) {
var configuration = ConfigurationSettings.getState();
var requestBuilder = new OpenAIChatCompletionRequest.Builder(
buildOpenAIMessages(model, callParameters))
.setModel(model)
.setMaxTokens(configuration.getMaxTokens())
.setStream(true)
.setTemperature(configuration.getTemperature());
if (callParameters.getMessage().isWebSearchIncluded()) {
// tri-state boolean
requestBuilder.setWebSearchIncluded(true);
}
var documentationDetails =
callParameters.getMessage().getDocumentationDetails();
if (documentationDetails != null) {
var requestDocumentationDetails = new RequestDocumentationDetails();
requestDocumentationDetails.setName(documentationDetails.getName());
requestDocumentationDetails.setUrl(documentationDetails.getUrl());
requestBuilder.setDocumentationDetails(requestDocumentationDetails);
}
return requestBuilder.build();
}
public static GoogleCompletionRequest buildGoogleChatCompletionRequest(
@Nullable String model,
CallParameters callParameters) {
var configuration = ConfigurationSettings.getState();
return new GoogleCompletionRequest.Builder(buildGoogleMessages(model, callParameters))
.generationConfig(new GoogleGenerationConfig.Builder()
.maxOutputTokens(configuration.getMaxTokens())
.temperature(configuration.getTemperature()).build()).build();
}
public static GoogleCompletionRequest buildGoogleEditCodeRequest(String input) {
var configuration = ConfigurationSettings.getState();
return new GoogleCompletionRequest.Builder(List.of(
new GoogleCompletionContent("user", List.of(EDIT_CODE_SYSTEM_PROMPT)),
new GoogleCompletionContent("model", List.of("Understood.")),
new GoogleCompletionContent("user", List.of(input))))
.generationConfig(new GoogleGenerationConfig.Builder()
.maxOutputTokens(configuration.getMaxTokens())
.temperature(configuration.getTemperature()).build()).build();
}
public static ClaudeCompletionRequest buildAnthropicChatCompletionRequest(
CallParameters callParameters) {
var configuration = ConfigurationSettings.getState();
var settings = AnthropicSettings.getCurrentState();
@ -316,7 +354,7 @@ public class CompletionRequestProvider {
request.setMaxTokens(configuration.getMaxTokens());
request.setStream(true);
request.setSystem(PersonaSettings.getSystemPrompt());
List<ClaudeCompletionMessage> messages = conversation.getMessages().stream()
List<ClaudeCompletionMessage> messages = callParameters.getConversation().getMessages().stream()
.filter(prevMessage -> prevMessage.getResponse() != null
&& !prevMessage.getResponse().isEmpty())
.flatMap(prevMessage -> Stream.of(
@ -339,7 +377,20 @@ public class CompletionRequestProvider {
return request;
}
public OllamaChatCompletionRequest buildOllamaChatCompletionRequest(
public static ClaudeCompletionRequest buildAnthropicEditCodeRequest(
String input) {
var configuration = ConfigurationSettings.getState();
var settings = AnthropicSettings.getCurrentState();
var request = new ClaudeCompletionRequest();
request.setModel(settings.getModel());
request.setMaxTokens(configuration.getMaxTokens());
request.setStream(true);
request.setSystem(EDIT_CODE_SYSTEM_PROMPT);
request.setMessages(List.of(new ClaudeCompletionStandardMessage("user", input)));
return request;
}
public static OllamaChatCompletionRequest buildOllamaChatCompletionRequest(
CallParameters callParameters
) {
var configuration = ConfigurationSettings.getState();
@ -354,7 +405,24 @@ public class CompletionRequestProvider {
.build();
}
private List<OllamaChatCompletionMessage> buildOllamaMessages(CallParameters callParameters) {
public static OllamaChatCompletionRequest buildOllamaEditCodeRequest(String input) {
var configuration = ConfigurationSettings.getState();
var settings =
ApplicationManager.getApplication().getService(OllamaSettings.class).getState();
return new OllamaChatCompletionRequest
.Builder(settings.getModel(), List.of(
new OllamaChatCompletionMessage("system", EDIT_CODE_SYSTEM_PROMPT, null),
new OllamaChatCompletionMessage("user", input, null)))
.setStream(true)
.setOptions(new OllamaParameters.Builder()
.numPredict(configuration.getMaxTokens())
.temperature((double) configuration.getTemperature())
.build())
.build();
}
private static List<OllamaChatCompletionMessage> buildOllamaMessages(
CallParameters callParameters) {
var message = callParameters.getMessage();
var messages = new ArrayList<OllamaChatCompletionMessage>();
if (callParameters.getConversationType() == ConversationType.DEFAULT) {
@ -367,7 +435,7 @@ public class CompletionRequestProvider {
);
}
for (var prevMessage : conversation.getMessages()) {
for (var prevMessage : callParameters.getConversation().getMessages()) {
if (callParameters.isRetry() && prevMessage.getId().equals(message.getId())) {
break;
}
@ -406,7 +474,8 @@ public class CompletionRequestProvider {
return messages;
}
private List<OpenAIChatCompletionMessage> buildOpenAIMessages(CallParameters callParameters) {
private static List<OpenAIChatCompletionMessage> buildOpenAIMessages(
CallParameters callParameters) {
var message = callParameters.getMessage();
var messages = new ArrayList<OpenAIChatCompletionMessage>();
if (callParameters.getConversationType() == ConversationType.DEFAULT) {
@ -425,7 +494,7 @@ public class CompletionRequestProvider {
new OpenAIChatCompletionStandardMessage("system", FIX_COMPILE_ERRORS_SYSTEM_PROMPT));
}
for (var prevMessage : conversation.getMessages()) {
for (var prevMessage : callParameters.getConversation().getMessages()) {
if (callParameters.isRetry() && prevMessage.getId().equals(message.getId())) {
break;
}
@ -463,7 +532,7 @@ public class CompletionRequestProvider {
return messages;
}
private List<OpenAIChatCompletionMessage> buildOpenAIMessages(
public static List<OpenAIChatCompletionMessage> buildOpenAIMessages(
@Nullable String model,
CallParameters callParameters) {
var messages = buildOpenAIMessages(callParameters);
@ -472,6 +541,7 @@ public class CompletionRequestProvider {
return messages;
}
var encodingManager = EncodingManager.getInstance();
int totalUsage = messages.parallelStream()
.mapToInt(encodingManager::countMessageTokens)
.sum() + ConfigurationSettings.getState().getMaxTokens();
@ -485,10 +555,14 @@ public class CompletionRequestProvider {
} catch (NoSuchElementException ex) {
return messages;
}
return tryReducingMessagesOrThrow(messages, totalUsage, modelMaxTokens);
return tryReducingMessagesOrThrow(
messages,
callParameters.getConversation().isDiscardTokenLimit(),
totalUsage,
modelMaxTokens);
}
private List<GoogleCompletionContent> buildGoogleMessages(CallParameters callParameters) {
private static List<GoogleCompletionContent> buildGoogleMessages(CallParameters callParameters) {
var message = callParameters.getMessage();
var messages = new ArrayList<GoogleCompletionContent>();
// Gemini API does not support direct 'system' prompts:
@ -503,7 +577,7 @@ public class CompletionRequestProvider {
messages.add(new GoogleCompletionContent("model", List.of("Understood.")));
}
for (var prevMessage : conversation.getMessages()) {
for (var prevMessage : callParameters.getConversation().getMessages()) {
if (callParameters.isRetry() && prevMessage.getId().equals(message.getId())) {
break;
}
@ -538,7 +612,7 @@ public class CompletionRequestProvider {
return messages;
}
private List<GoogleCompletionContent> buildGoogleMessages(
private static List<GoogleCompletionContent> buildGoogleMessages(
@Nullable String model,
CallParameters callParameters) {
var messages = buildGoogleMessages(callParameters);
@ -547,6 +621,7 @@ public class CompletionRequestProvider {
return messages;
}
var encodingManager = EncodingManager.getInstance();
int totalUsage = messages.parallelStream()
.mapToInt(message -> encodingManager.countMessageTokens(message.getRole(),
String.join(",", message.getParts().stream().map(GoogleContentPart::getText).toList())))
@ -561,19 +636,24 @@ public class CompletionRequestProvider {
} catch (NoSuchElementException ex) {
return messages;
}
return tryReducingGoogleMessagesOrThrow(messages, totalUsage, modelMaxTokens);
return tryReducingGoogleMessagesOrThrow(
messages,
callParameters.getConversation().isDiscardTokenLimit(),
totalUsage,
modelMaxTokens);
}
private List<OpenAIChatCompletionMessage> tryReducingMessagesOrThrow(
private static List<OpenAIChatCompletionMessage> tryReducingMessagesOrThrow(
List<OpenAIChatCompletionMessage> messages,
boolean discardTokenLimit,
int totalUsage,
int modelMaxTokens) {
if (!ConversationsState.getInstance().discardAllTokenLimits) {
if (!conversation.isDiscardTokenLimit()) {
if (!discardTokenLimit) {
throw new TotalUsageExceededException();
}
}
var encodingManager = EncodingManager.getInstance();
// skip the system prompt
for (int i = 1; i < messages.size(); i++) {
if (totalUsage <= modelMaxTokens) {
@ -590,16 +670,17 @@ public class CompletionRequestProvider {
return messages.stream().filter(Objects::nonNull).toList();
}
private List<GoogleCompletionContent> tryReducingGoogleMessagesOrThrow(
private static List<GoogleCompletionContent> tryReducingGoogleMessagesOrThrow(
List<GoogleCompletionContent> messages,
boolean discardTokenLimit,
int totalUsage,
int modelMaxTokens) {
if (!ConversationsState.getInstance().discardAllTokenLimits) {
if (!conversation.isDiscardTokenLimit()) {
if (!discardTokenLimit) {
throw new TotalUsageExceededException();
}
}
var encodingManager = EncodingManager.getInstance();
// skip the system prompt
for (int i = 1; i < messages.size(); i++) {
if (totalUsage <= modelMaxTokens) {

View file

@ -6,6 +6,7 @@ import com.intellij.openapi.diagnostic.Logger;
import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestFactory;
import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestProvider;
import ee.carlrobert.codegpt.codecompletions.InfillRequest;
import ee.carlrobert.codegpt.actions.editor.EditCodeRequestParams;
import ee.carlrobert.codegpt.completions.llama.LlamaModel;
import ee.carlrobert.codegpt.completions.llama.PromptTemplate;
import ee.carlrobert.codegpt.credentials.CredentialsStore;
@ -16,7 +17,6 @@ import ee.carlrobert.codegpt.settings.service.ServiceType;
import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings;
import ee.carlrobert.codegpt.settings.service.azure.AzureSettings;
import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTServiceSettings;
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings;
import ee.carlrobert.codegpt.settings.service.google.GoogleSettings;
import ee.carlrobert.codegpt.settings.service.google.GoogleSettingsState;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
@ -83,10 +83,9 @@ public final class CompletionRequestService {
CallParameters callParameters,
CompletionEventListener<String> eventListener) {
var application = ApplicationManager.getApplication();
var requestProvider = new CompletionRequestProvider(callParameters.getConversation());
return switch (GeneralSettings.getSelectedService()) {
case CODEGPT -> CompletionClientProvider.getCodeGPTClient().getChatCompletionAsync(
requestProvider.buildOpenAIChatCompletionRequest(
CompletionRequestProvider.buildOpenAIChatCompletionRequest(
application.getService(CodeGPTServiceSettings.class)
.getState()
.getChatCompletionSettings()
@ -95,35 +94,32 @@ public final class CompletionRequestService {
eventListener
);
case OPENAI -> CompletionClientProvider.getOpenAIClient().getChatCompletionAsync(
requestProvider.buildOpenAIChatCompletionRequest(
CompletionRequestProvider.buildOpenAIChatCompletionRequest(
OpenAISettings.getCurrentState().getModel(),
callParameters),
eventListener);
case CUSTOM_OPENAI -> getCustomOpenAIChatCompletionAsync(
requestProvider.buildCustomOpenAIChatCompletionRequest(
application.getService(CustomServiceSettings.class)
.getState()
.getChatCompletionSettings(),
callParameters),
CompletionRequestProvider.buildCustomOpenAIChatCompletionRequest(callParameters),
eventListener);
case ANTHROPIC -> CompletionClientProvider.getClaudeClient().getCompletionAsync(
requestProvider.buildAnthropicChatCompletionRequest(callParameters),
CompletionRequestProvider.buildAnthropicChatCompletionRequest(callParameters),
eventListener);
case AZURE -> CompletionClientProvider.getAzureClient().getChatCompletionAsync(
requestProvider.buildOpenAIChatCompletionRequest(null, callParameters),
CompletionRequestProvider.buildOpenAIChatCompletionRequest(null, callParameters),
eventListener);
case LLAMA_CPP -> CompletionClientProvider.getLlamaClient().getChatCompletionAsync(
requestProvider.buildLlamaCompletionRequest(
CompletionRequestProvider.buildLlamaCompletionRequest(
callParameters.getMessage(),
callParameters.getConversation(),
callParameters.getConversationType()),
eventListener);
case OLLAMA -> CompletionClientProvider.getOllamaClient().getChatCompletionAsync(
requestProvider.buildOllamaChatCompletionRequest(callParameters),
CompletionRequestProvider.buildOllamaChatCompletionRequest(callParameters),
eventListener);
case GOOGLE -> {
var settings = application.getService(GoogleSettings.class).getState();
yield CompletionClientProvider.getGoogleClient().getChatCompletionAsync(
requestProvider.buildGoogleChatCompletionRequest(
CompletionRequestProvider.buildGoogleChatCompletionRequest(
settings.getModel(),
callParameters),
settings.getModel(),
@ -132,6 +128,54 @@ public final class CompletionRequestService {
};
}
public EventSource getEditCodeCompletionAsync(
EditCodeRequestParams params,
CompletionEventListener<String> eventListener) {
var input = "%s\n\n%s".formatted(params.getPrompt(), params.getSelectedText());
return switch (GeneralSettings.getSelectedService()) {
case CODEGPT -> CompletionClientProvider.getCodeGPTClient().getChatCompletionAsync(
CompletionRequestProvider.buildEditCodeRequest(
input,
ApplicationManager.getApplication().getService(CodeGPTServiceSettings.class)
.getState()
.getChatCompletionSettings()
.getModel()
),
eventListener
);
case OPENAI -> CompletionClientProvider.getOpenAIClient().getChatCompletionAsync(
CompletionRequestProvider.buildEditCodeRequest(
input,
OpenAISettings.getCurrentState().getModel()),
eventListener);
case CUSTOM_OPENAI -> getCustomOpenAIChatCompletionAsync(
CompletionRequestProvider.buildCustomOpenAIEditCodeRequest(input),
eventListener);
case ANTHROPIC -> CompletionClientProvider.getClaudeClient().getCompletionAsync(
CompletionRequestProvider.buildAnthropicEditCodeRequest(input),
eventListener);
case AZURE -> CompletionClientProvider.getAzureClient().getChatCompletionAsync(
CompletionRequestProvider.buildEditCodeRequest(input, null),
eventListener);
case LLAMA_CPP -> CompletionClientProvider.getLlamaClient().getChatCompletionAsync(
CompletionRequestProvider.buildLlamaEditCodeRequest(input),
eventListener);
case OLLAMA -> CompletionClientProvider.getOllamaClient().getChatCompletionAsync(
CompletionRequestProvider.buildOllamaEditCodeRequest(input),
eventListener);
case GOOGLE -> {
var model =
ApplicationManager.getApplication().getService(GoogleSettings.class)
.getState()
.getModel();
yield CompletionClientProvider.getGoogleClient().getChatCompletionAsync(
CompletionRequestProvider.buildGoogleEditCodeRequest(input),
model,
eventListener);
}
};
}
public EventSource getCodeCompletionAsync(
InfillRequest requestDetails,
CompletionEventListener<String> eventListener) {

View file

@ -7,4 +7,5 @@ public enum ConversationType {
FIX_COMPILE_ERRORS,
MULTI_FILE,
INLINE_COMPLETION,
EDIT_CODE
}

View file

@ -172,11 +172,16 @@ public enum LlamaModel {
HuggingFaceModel.CODE_QWEN_1_5_7B_Q6_K)),
CODE_QWEN2_5_CODER(
"CodeQwen2.5 Coder", """
Qwen2.5-Coder is the latest series of Code-Specific Qwen large language models (formerly known as CodeQwen).\
Qwen2.5-Coder is the latest series of Code-Specific Qwen large language models \
(formerly known as CodeQwen).
It brings the following improvements upon CodeQwen1.5:
- Significantly improvements in code generation, code reasoning and code fixing. Base on the strong Qwen2.5, we scale up the training tokens into 5.5 trillion including source code, text-code grounding, Synthetic data, etc.
- A more comprehensive foundation for real-world applications such as Code Agents. Not only enhancing coding capabilities but also maintaining its strengths in mathematics and general competencies.
- Significantly improvements in code generation, code reasoning and code fixing. \
Base on the strong Qwen2.5, we scale up the training tokens into 5.5 trillion including \
source code, text-code grounding, Synthetic data, etc.
- A more comprehensive foundation for real-world applications such as Code Agents. \
Not only enhancing coding capabilities but also maintaining its strengths in \
mathematics and general competencies.
- Long-context Support up to 128K tokens.
""",
PromptTemplate.CODE_QWEN,

View file

@ -9,12 +9,12 @@ import com.intellij.openapi.util.TextRange
import com.intellij.util.ui.AsyncProcessIcon
import com.intellij.openapi.util.text.StringUtil
import com.jetbrains.rd.util.AtomicReference
import ee.carlrobert.codegpt.completions.CompletionClientProvider
import ee.carlrobert.codegpt.completions.CompletionRequestProvider
import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTServiceSettings
import ee.carlrobert.codegpt.completions.CompletionRequestService
import ee.carlrobert.codegpt.ui.ObservableProperties
import javax.swing.JButton
data class EditCodeRequestParams(val prompt: String, val selectedText: String)
class EditCodeSubmissionHandler(
private val editor: Editor,
private val observableProperties: ObservableProperties,
@ -43,12 +43,8 @@ class EditCodeSubmissionHandler(
}
runInEdt { editor.selectionModel.removeSelection() }
// TODO: Support other providers
CompletionClientProvider.getCodeGPTClient().getChatCompletionAsync(
CompletionRequestProvider.buildEditCodeRequest(
"$userPrompt\n\n$selectedText",
service<CodeGPTServiceSettings>().state.chatCompletionSettings.model
),
service<CompletionRequestService>().getEditCodeCompletionAsync(
EditCodeRequestParams(userPrompt, selectedText),
EditCodeCompletionListener(
editor,
selectionTextRange,

View file

@ -21,7 +21,6 @@ import com.intellij.util.ui.JBUI
import ee.carlrobert.codegpt.CodeGPTBundle
import ee.carlrobert.codegpt.actions.editor.EditCodeSubmissionHandler
import ee.carlrobert.codegpt.settings.GeneralSettings
import ee.carlrobert.codegpt.settings.service.ServiceType.CODEGPT
import ee.carlrobert.codegpt.toolwindow.chat.ui.textarea.ModelComboBoxAction
import ee.carlrobert.codegpt.util.ApplicationUtil
import kotlinx.coroutines.CoroutineScope
@ -195,8 +194,7 @@ class EditCodePopover(private val editor: Editor) {
ModelComboBoxAction(
ApplicationUtil.findCurrentProject(),
{},
GeneralSettings.getSelectedService(),
listOf(CODEGPT)
GeneralSettings.getSelectedService()
)
.createCustomComponent(ActionPlaces.UNKNOWN)
).align(AlignX.RIGHT)

View file

@ -4,10 +4,8 @@ import com.intellij.openapi.components.service
import ee.carlrobert.codegpt.conversations.Conversation
import ee.carlrobert.codegpt.conversations.ConversationService
import ee.carlrobert.codegpt.conversations.message.Message
import ee.carlrobert.codegpt.settings.GeneralSettings
import ee.carlrobert.codegpt.settings.persona.DEFAULT_PROMPT
import ee.carlrobert.codegpt.settings.persona.PersonaSettings
import ee.carlrobert.codegpt.settings.service.ServiceType
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.groups.Tuple
@ -24,8 +22,7 @@ class CompletionRequestProviderTest : IntegrationTest() {
conversation.addMessage(firstMessage)
conversation.addMessage(secondMessage)
val request = CompletionRequestProvider(conversation)
.buildOpenAIChatCompletionRequest(
val request = CompletionRequestProvider.buildOpenAIChatCompletionRequest(
OpenAIChatCompletionModel.GPT_3_5.code,
CallParameters(
conversation,
@ -54,8 +51,7 @@ class CompletionRequestProviderTest : IntegrationTest() {
conversation.addMessage(firstMessage)
conversation.addMessage(secondMessage)
val request = CompletionRequestProvider(conversation)
.buildOpenAIChatCompletionRequest(
val request = CompletionRequestProvider.buildOpenAIChatCompletionRequest(
OpenAIChatCompletionModel.GPT_3_5.code,
CallParameters(
conversation,
@ -84,8 +80,7 @@ class CompletionRequestProviderTest : IntegrationTest() {
conversation.addMessage(firstMessage)
conversation.addMessage(secondMessage)
val request = CompletionRequestProvider(conversation)
.buildOpenAIChatCompletionRequest(
val request = CompletionRequestProvider.buildOpenAIChatCompletionRequest(
OpenAIChatCompletionModel.GPT_3_5.code,
CallParameters(
conversation,
@ -115,8 +110,7 @@ class CompletionRequestProviderTest : IntegrationTest() {
conversation.addMessage(remainingMessage)
conversation.discardTokenLimits()
val request = CompletionRequestProvider(conversation)
.buildOpenAIChatCompletionRequest(
val request = CompletionRequestProvider.buildOpenAIChatCompletionRequest(
OpenAIChatCompletionModel.GPT_3_5.code,
CallParameters(
conversation,
@ -142,8 +136,7 @@ class CompletionRequestProviderTest : IntegrationTest() {
conversation.addMessage(createDummyMessage(1500))
assertThrows(TotalUsageExceededException::class.java) {
CompletionRequestProvider(conversation)
.buildOpenAIChatCompletionRequest(
CompletionRequestProvider.buildOpenAIChatCompletionRequest(
OpenAIChatCompletionModel.GPT_3_5.code,
CallParameters(
conversation,

View file

@ -47,6 +47,7 @@ interface ShortcutsTestMixin {
GeneralSettings.getCurrentState().selectedService = ServiceType.LLAMA_CPP
LlamaSettings.getCurrentState().serverPort = null
LlamaSettings.getCurrentState().isCodeCompletionsEnabled = codeCompletionsEnabled
LlamaSettings.getCurrentState().huggingFaceModel = HuggingFaceModel.CODE_LLAMA_7B_Q4
}
fun useOllamaService() {