mirror of
https://github.com/carlrobertoh/ProxyAI.git
synced 2026-05-10 12:10:14 +00:00
* Implement support for You Pro modes: Default, Agent, Custom with various 3rd-party models and Research * Update list of You modes/models depending on user having a subscription * Add default value for chatMode
348 lines
14 KiB
Java
package ee.carlrobert.codegpt.completions;
|
|
|
|
import static ee.carlrobert.codegpt.util.file.FileUtil.getResourceContent;
|
|
import static java.lang.String.format;
|
|
import static java.util.stream.Collectors.joining;
|
|
import static java.util.stream.Collectors.toList;
|
|
|
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
import com.intellij.openapi.application.ApplicationManager;
|
|
import com.intellij.openapi.diagnostic.Logger;
|
|
import ee.carlrobert.codegpt.CodeGPTPlugin;
|
|
import ee.carlrobert.codegpt.EncodingManager;
|
|
import ee.carlrobert.codegpt.completions.llama.LlamaModel;
|
|
import ee.carlrobert.codegpt.completions.llama.PromptTemplate;
|
|
import ee.carlrobert.codegpt.conversations.Conversation;
|
|
import ee.carlrobert.codegpt.conversations.ConversationsState;
|
|
import ee.carlrobert.codegpt.conversations.message.Message;
|
|
import ee.carlrobert.codegpt.credentials.CustomServiceCredentialManager;
|
|
import ee.carlrobert.codegpt.settings.GeneralSettings;
|
|
import ee.carlrobert.codegpt.settings.IncludedFilesSettings;
|
|
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings;
|
|
import ee.carlrobert.codegpt.settings.service.ServiceType;
|
|
import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings;
|
|
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings;
|
|
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettingsState;
|
|
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
|
|
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
|
|
import ee.carlrobert.codegpt.settings.service.you.YouSettings;
|
|
import ee.carlrobert.codegpt.telemetry.core.configuration.TelemetryConfiguration;
|
|
import ee.carlrobert.codegpt.telemetry.core.service.UserId;
|
|
import ee.carlrobert.embedding.EmbeddingsService;
|
|
import ee.carlrobert.embedding.ReferencedFile;
|
|
import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionRequest;
|
|
import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionRequestMessage;
|
|
import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest;
|
|
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel;
|
|
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionMessage;
|
|
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest;
|
|
import ee.carlrobert.llm.client.you.completion.YouCompletionRequest;
|
|
import ee.carlrobert.llm.client.you.completion.YouCompletionRequestMessage;
|
|
import java.nio.charset.StandardCharsets;
|
|
import java.util.ArrayList;
|
|
import java.util.List;
|
|
import java.util.Map;
|
|
import java.util.NoSuchElementException;
|
|
import java.util.Objects;
|
|
import java.util.UUID;
|
|
import java.util.stream.Collectors;
|
|
import java.util.stream.Stream;
|
|
import okhttp3.Request;
|
|
import okhttp3.RequestBody;
|
|
import org.jetbrains.annotations.Nullable;
|
|
|
|
public class CompletionRequestProvider {
|
|
|
|
private static final Logger LOG = Logger.getInstance(CompletionRequestProvider.class);
|
|
|
|
public static final String COMPLETION_SYSTEM_PROMPT = getResourceContent(
|
|
"/prompts/default-completion-system-prompt.txt");
|
|
|
|
public static final String GENERATE_COMMIT_MESSAGE_SYSTEM_PROMPT = getResourceContent(
|
|
"/prompts/generate-commit-message-system-prompt.txt");
|
|
|
|
public static final String FIX_COMPILE_ERRORS_SYSTEM_PROMPT = getResourceContent(
|
|
"/prompts/fix-compile-errors.txt");
|
|
|
|
private final EncodingManager encodingManager = EncodingManager.getInstance();
|
|
private final EmbeddingsService embeddingsService;
|
|
private final Conversation conversation;
|
|
|
|
public CompletionRequestProvider(Conversation conversation) {
|
|
this.embeddingsService = new EmbeddingsService(
|
|
CompletionClientProvider.getOpenAIClient(),
|
|
CodeGPTPlugin.getPluginBasePath());
|
|
this.conversation = conversation;
|
|
}
|
|
|
|
public static String getPromptWithContext(List<ReferencedFile> referencedFiles,
|
|
String userPrompt) {
|
|
var includedFilesSettings = IncludedFilesSettings.getCurrentState();
|
|
var repeatableContext = referencedFiles.stream()
|
|
.map(item -> includedFilesSettings.getRepeatableContext()
|
|
.replace("{FILE_PATH}", item.getFilePath())
|
|
.replace("{FILE_CONTENT}", format(
|
|
"```%s\n%s\n```",
|
|
item.getFileExtension(),
|
|
item.getFileContent().trim())))
|
|
.collect(joining("\n\n"));
|
|
|
|
return includedFilesSettings.getPromptTemplate()
|
|
.replace("{REPEATABLE_CONTEXT}", repeatableContext)
|
|
.replace("{QUESTION}", userPrompt);
|
|
}
|
|
|
|
public static OpenAIChatCompletionRequest buildOpenAILookupCompletionRequest(String context) {
|
|
return new OpenAIChatCompletionRequest.Builder(
|
|
List.of(
|
|
new OpenAIChatCompletionMessage("system",
|
|
getResourceContent("/prompts/method-name-generator.txt")),
|
|
new OpenAIChatCompletionMessage("user", context)))
|
|
.setModel(OpenAISettings.getCurrentState().getModel())
|
|
.setStream(false)
|
|
.build();
|
|
}
|
|
|
|
public static Request buildCustomOpenAILookupCompletionRequest(String context) {
|
|
return buildCustomOpenAIChatCompletionRequest(
|
|
CustomServiceSettings.getCurrentState(),
|
|
List.of(
|
|
new OpenAIChatCompletionMessage(
|
|
"system",
|
|
getResourceContent("/prompts/method-name-generator.txt")),
|
|
new OpenAIChatCompletionMessage("user", context)),
|
|
false);
|
|
}
|
|
|
|
public static LlamaCompletionRequest buildLlamaLookupCompletionRequest(String context) {
|
|
return new LlamaCompletionRequest.Builder(PromptTemplate.LLAMA
|
|
.buildPrompt(getResourceContent("/prompts/method-name-generator.txt"), context, List.of()))
|
|
.setStream(false)
|
|
.build();
|
|
}
|
|
|
|
public LlamaCompletionRequest buildLlamaCompletionRequest(
|
|
Message message,
|
|
ConversationType conversationType) {
|
|
var settings = LlamaSettings.getCurrentState();
|
|
PromptTemplate promptTemplate;
|
|
if (settings.isRunLocalServer()) {
|
|
promptTemplate = settings.isUseCustomModel()
|
|
? settings.getLocalModelPromptTemplate()
|
|
: LlamaModel.findByHuggingFaceModel(settings.getHuggingFaceModel()).getPromptTemplate();
|
|
} else {
|
|
promptTemplate = settings.getRemoteModelPromptTemplate();
|
|
}
|
|
|
|
var systemPrompt = COMPLETION_SYSTEM_PROMPT;
|
|
if (conversationType == ConversationType.FIX_COMPILE_ERRORS) {
|
|
systemPrompt = FIX_COMPILE_ERRORS_SYSTEM_PROMPT;
|
|
}
|
|
|
|
var prompt = promptTemplate.buildPrompt(
|
|
systemPrompt,
|
|
message.getPrompt(),
|
|
conversation.getMessages());
|
|
var configuration = ConfigurationSettings.getCurrentState();
|
|
return new LlamaCompletionRequest.Builder(prompt)
|
|
.setN_predict(configuration.getMaxTokens())
|
|
.setTemperature(configuration.getTemperature())
|
|
.setTop_k(settings.getTopK())
|
|
.setTop_p(settings.getTopP())
|
|
.setMin_p(settings.getMinP())
|
|
.setRepeat_penalty(settings.getRepeatPenalty())
|
|
.build();
|
|
}
|
|
|
|
public YouCompletionRequest buildYouCompletionRequest(Message message) {
|
|
var requestBuilder = new YouCompletionRequest.Builder(message.getPrompt())
|
|
.setUseGPT4Model(YouSettings.getCurrentState().isUseGPT4Model())
|
|
.setChatMode(YouSettings.getCurrentState().getChatMode())
|
|
.setCustomModel(YouSettings.getCurrentState().getCustomModel())
|
|
.setChatHistory(conversation.getMessages().stream()
|
|
.map(prevMessage -> new YouCompletionRequestMessage(
|
|
prevMessage.getPrompt(),
|
|
prevMessage.getResponse()))
|
|
.collect(toList()));
|
|
if (TelemetryConfiguration.getInstance().isEnabled()
|
|
&& !ApplicationManager.getApplication().isUnitTestMode()) {
|
|
requestBuilder.setUserId(UUID.fromString(UserId.INSTANCE.get()));
|
|
}
|
|
return requestBuilder.build();
|
|
}
|
|
|
|
public OpenAIChatCompletionRequest buildOpenAIChatCompletionRequest(
|
|
@Nullable String model,
|
|
CallParameters callParameters,
|
|
boolean useContextualSearch,
|
|
@Nullable String overriddenPath) {
|
|
var configuration = ConfigurationSettings.getCurrentState();
|
|
var builder = new OpenAIChatCompletionRequest.Builder(
|
|
buildMessages(model, callParameters, useContextualSearch))
|
|
.setModel(model)
|
|
.setMaxTokens(configuration.getMaxTokens())
|
|
.setStream(true)
|
|
.setTemperature(configuration.getTemperature());
|
|
|
|
if (overriddenPath != null) {
|
|
builder.setOverriddenPath(overriddenPath);
|
|
}
|
|
|
|
return builder.build();
|
|
}
|
|
|
|
public Request buildCustomOpenAIChatCompletionRequest(
|
|
CustomServiceSettingsState customConfiguration,
|
|
CallParameters callParameters) {
|
|
return buildCustomOpenAIChatCompletionRequest(
|
|
customConfiguration,
|
|
buildMessages(callParameters, false),
|
|
true);
|
|
}
|
|
|
|
private static Request buildCustomOpenAIChatCompletionRequest(
|
|
CustomServiceSettingsState customConfiguration,
|
|
List<OpenAIChatCompletionMessage> messages,
|
|
boolean streamRequest) {
|
|
var requestBuilder = new Request.Builder().url(customConfiguration.getUrl().trim());
|
|
for (var entry : customConfiguration.getHeaders().entrySet()) {
|
|
String value = entry.getValue();
|
|
if (value.contains("$CUSTOM_SERVICE_API_KEY")) {
|
|
value = value.replace("$CUSTOM_SERVICE_API_KEY",
|
|
CustomServiceCredentialManager.getInstance().getCredential());
|
|
}
|
|
requestBuilder.addHeader(entry.getKey(), value);
|
|
}
|
|
|
|
var body = customConfiguration.getBody().entrySet().stream()
|
|
.collect(Collectors.toMap(
|
|
Map.Entry::getKey,
|
|
entry -> {
|
|
if (!streamRequest && "stream".equals(entry.getKey())) {
|
|
return false;
|
|
}
|
|
|
|
var value = entry.getValue();
|
|
if (value instanceof String && "$OPENAI_MESSAGES".equals(((String) value).trim())) {
|
|
return messages;
|
|
}
|
|
return value;
|
|
}
|
|
));
|
|
|
|
try {
|
|
var requestBody = RequestBody.create(new ObjectMapper()
|
|
.writerWithDefaultPrettyPrinter()
|
|
.writeValueAsString(body)
|
|
.getBytes(StandardCharsets.UTF_8));
|
|
return requestBuilder.post(requestBody).build();
|
|
} catch (JsonProcessingException e) {
|
|
throw new RuntimeException(e);
|
|
}
|
|
}
|
|
|
|
public ClaudeCompletionRequest buildAnthropicChatCompletionRequest(
|
|
CallParameters callParameters) {
|
|
var configuration = ConfigurationSettings.getCurrentState();
|
|
var settings = AnthropicSettings.getCurrentState();
|
|
var request = new ClaudeCompletionRequest();
|
|
request.setModel(settings.getModel());
|
|
request.setMaxTokens(configuration.getMaxTokens());
|
|
request.setStream(true);
|
|
request.setSystem(COMPLETION_SYSTEM_PROMPT);
|
|
var messages = conversation.getMessages().stream()
|
|
.filter(prevMessage -> prevMessage.getResponse() != null
|
|
&& !prevMessage.getResponse().isEmpty())
|
|
.flatMap(prevMessage -> Stream.of(
|
|
new ClaudeCompletionRequestMessage("user", prevMessage.getPrompt()),
|
|
new ClaudeCompletionRequestMessage("assistant", prevMessage.getResponse())))
|
|
.collect(toList());
|
|
messages.add(
|
|
new ClaudeCompletionRequestMessage("user", callParameters.getMessage().getPrompt()));
|
|
request.setMessages(messages);
|
|
return request;
|
|
}
|
|
|
|
private List<OpenAIChatCompletionMessage> buildMessages(
|
|
CallParameters callParameters,
|
|
boolean useContextualSearch) {
|
|
var message = callParameters.getMessage();
|
|
var messages = new ArrayList<OpenAIChatCompletionMessage>();
|
|
if (useContextualSearch) {
|
|
var prompt = embeddingsService.buildPromptWithContext(
|
|
message.getPrompt());
|
|
LOG.info("Retrieved context:\n" + prompt);
|
|
messages.add(new OpenAIChatCompletionMessage("user", prompt));
|
|
} else {
|
|
if (callParameters.getConversationType() == ConversationType.DEFAULT) {
|
|
messages.add(new OpenAIChatCompletionMessage(
|
|
"system",
|
|
ConfigurationSettings.getCurrentState().getSystemPrompt()));
|
|
}
|
|
if (callParameters.getConversationType() == ConversationType.FIX_COMPILE_ERRORS) {
|
|
messages.add(new OpenAIChatCompletionMessage("system", FIX_COMPILE_ERRORS_SYSTEM_PROMPT));
|
|
}
|
|
|
|
for (var prevMessage : conversation.getMessages()) {
|
|
if (callParameters.isRetry() && prevMessage.getId().equals(message.getId())) {
|
|
break;
|
|
}
|
|
messages.add(new OpenAIChatCompletionMessage("user", prevMessage.getPrompt()));
|
|
messages.add(new OpenAIChatCompletionMessage("assistant", prevMessage.getResponse()));
|
|
}
|
|
messages.add(new OpenAIChatCompletionMessage("user", message.getPrompt()));
|
|
}
|
|
return messages;
|
|
}
|
|
|
|
private List<OpenAIChatCompletionMessage> buildMessages(
|
|
@Nullable String model,
|
|
CallParameters callParameters,
|
|
boolean useContextualSearch) {
|
|
var messages = buildMessages(callParameters, useContextualSearch);
|
|
|
|
if (model == null
|
|
|| GeneralSettings.getCurrentState().getSelectedService() == ServiceType.YOU) {
|
|
return messages;
|
|
}
|
|
|
|
int totalUsage = messages.parallelStream()
|
|
.mapToInt(encodingManager::countMessageTokens)
|
|
.sum() + ConfigurationSettings.getCurrentState().getMaxTokens();
|
|
int modelMaxTokens;
|
|
try {
|
|
modelMaxTokens = OpenAIChatCompletionModel.findByCode(model).getMaxTokens();
|
|
|
|
if (totalUsage <= modelMaxTokens) {
|
|
return messages;
|
|
}
|
|
} catch (NoSuchElementException ex) {
|
|
return messages;
|
|
}
|
|
return tryReducingMessagesOrThrow(messages, totalUsage, modelMaxTokens);
|
|
}
|
|
|
|
private List<OpenAIChatCompletionMessage> tryReducingMessagesOrThrow(
|
|
List<OpenAIChatCompletionMessage> messages,
|
|
int totalUsage,
|
|
int modelMaxTokens) {
|
|
if (!ConversationsState.getInstance().discardAllTokenLimits) {
|
|
if (!conversation.isDiscardTokenLimit()) {
|
|
throw new TotalUsageExceededException();
|
|
}
|
|
}
|
|
|
|
// skip the system prompt
|
|
for (int i = 1; i < messages.size(); i++) {
|
|
if (totalUsage <= modelMaxTokens) {
|
|
break;
|
|
}
|
|
|
|
totalUsage -= encodingManager.countMessageTokens(messages.get(i));
|
|
messages.set(i, null);
|
|
}
|
|
|
|
return messages.stream().filter(Objects::nonNull).collect(toList());
|
|
}
|
|
}
|