fix: Use System Prompt from user configuration (#454) (#455)

This commit is contained in:
Rene Leonhardt 2024-04-15 10:42:42 +02:00 committed by GitHub
parent 0dfaa128b7
commit 5f16213bd1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 44 additions and 38 deletions

View file

@ -1,5 +1,7 @@
package ee.carlrobert.codegpt.completions;
import static ee.carlrobert.codegpt.completions.ConversationType.DEFAULT;
import static ee.carlrobert.codegpt.completions.ConversationType.FIX_COMPILE_ERRORS;
import static ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.CUSTOM_SERVICE_API_KEY;
import static ee.carlrobert.codegpt.util.file.FileUtil.getResourceContent;
import static java.lang.String.format;
@ -57,6 +59,7 @@ import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@ -74,6 +77,8 @@ public class CompletionRequestProvider {
public static final String FIX_COMPILE_ERRORS_SYSTEM_PROMPT = getResourceContent(
"/prompts/fix-compile-errors.txt");
private static final Set<ConversationType> OPENAI_SYSTEM_CONVERSATION_TYPES = Set.of(
DEFAULT, FIX_COMPILE_ERRORS);
private final EncodingManager encodingManager = EncodingManager.getInstance();
private final Conversation conversation;
@ -151,10 +156,8 @@ public class CompletionRequestProvider {
promptTemplate = settings.getRemoteModelPromptTemplate();
}
var systemPrompt = COMPLETION_SYSTEM_PROMPT;
if (conversationType == ConversationType.FIX_COMPILE_ERRORS) {
systemPrompt = FIX_COMPILE_ERRORS_SYSTEM_PROMPT;
}
var systemPrompt = conversationType == FIX_COMPILE_ERRORS
? FIX_COMPILE_ERRORS_SYSTEM_PROMPT : ConfigurationSettings.getSystemPrompt();
var prompt = promptTemplate.buildPrompt(
systemPrompt,
@ -257,7 +260,7 @@ public class CompletionRequestProvider {
request.setModel(settings.getModel());
request.setMaxTokens(configuration.getMaxTokens());
request.setStream(true);
request.setSystem(COMPLETION_SYSTEM_PROMPT);
request.setSystem(ConfigurationSettings.getSystemPrompt());
List<ClaudeCompletionMessage> messages = conversation.getMessages().stream()
.filter(prevMessage -> prevMessage.getResponse() != null
&& !prevMessage.getResponse().isEmpty())
@ -284,14 +287,10 @@ public class CompletionRequestProvider {
private List<OpenAIChatCompletionMessage> buildMessages(CallParameters callParameters) {
var message = callParameters.getMessage();
var messages = new ArrayList<OpenAIChatCompletionMessage>();
if (callParameters.getConversationType() == ConversationType.DEFAULT) {
messages.add(new OpenAIChatCompletionStandardMessage(
"system",
ConfigurationSettings.getCurrentState().getSystemPrompt()));
}
if (callParameters.getConversationType() == ConversationType.FIX_COMPILE_ERRORS) {
messages.add(
new OpenAIChatCompletionStandardMessage("system", FIX_COMPILE_ERRORS_SYSTEM_PROMPT));
if (OPENAI_SYSTEM_CONVERSATION_TYPES.contains(callParameters.getConversationType())) {
String content = DEFAULT == callParameters.getConversationType()
? ConfigurationSettings.getSystemPrompt() : FIX_COMPILE_ERRORS_SYSTEM_PROMPT;
messages.add(new OpenAIChatCompletionStandardMessage("system", content));
}
for (var prevMessage : conversation.getMessages()) {

View file

@ -94,7 +94,7 @@ public class ConfigurationComponent {
maxTokensField.setValue(configuration.getMaxTokens());
systemPromptTextArea = new JTextArea();
if (configuration.getSystemPrompt().isEmpty()) {
if (configuration.getSystemPrompt().isBlank()) {
// for backward compatibility
systemPromptTextArea.setText(COMPLETION_SYSTEM_PROMPT);
} else {

View file

@ -31,4 +31,8 @@ public class ConfigurationSettings implements PersistentStateComponent<Configura
public static ConfigurationSettings getInstance() {
return ApplicationManager.getApplication().getService(ConfigurationSettings.class);
}
public static String getSystemPrompt() {
return getCurrentState().getSystemPrompt();
}
}

View file

@ -12,8 +12,7 @@ public class TotalTokensDetails {
private int referencedFilesTokens;
public TotalTokensDetails(EncodingManager encodingManager) {
systemPromptTokens = encodingManager.countTokens(
ConfigurationSettings.getCurrentState().getSystemPrompt());
systemPromptTokens = encodingManager.countTokens(ConfigurationSettings.getSystemPrompt());
}
public int getSystemPromptTokens() {