fix: Use System Prompt from user configuration (#454) (#455)

This commit is contained in:
Rene Leonhardt 2024-04-15 10:42:42 +02:00 committed by Carl-Robert Linnupuu
parent 763c65e1f6
commit 3a9e212582
7 changed files with 44 additions and 38 deletions

View file

@ -1,5 +1,7 @@
package ee.carlrobert.codegpt.completions;
import static ee.carlrobert.codegpt.completions.ConversationType.DEFAULT;
import static ee.carlrobert.codegpt.completions.ConversationType.FIX_COMPILE_ERRORS;
import static ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.CUSTOM_SERVICE_API_KEY;
import static ee.carlrobert.codegpt.util.file.FileUtil.getResourceContent;
import static java.lang.String.format;
@ -57,6 +59,7 @@ import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@ -74,6 +77,8 @@ public class CompletionRequestProvider {
public static final String FIX_COMPILE_ERRORS_SYSTEM_PROMPT = getResourceContent(
"/prompts/fix-compile-errors.txt");
private static final Set<ConversationType> OPENAI_SYSTEM_CONVERSATION_TYPES = Set.of(
DEFAULT, FIX_COMPILE_ERRORS);
private final EncodingManager encodingManager = EncodingManager.getInstance();
private final Conversation conversation;
@ -151,10 +156,8 @@ public class CompletionRequestProvider {
promptTemplate = settings.getRemoteModelPromptTemplate();
}
var systemPrompt = COMPLETION_SYSTEM_PROMPT;
if (conversationType == ConversationType.FIX_COMPILE_ERRORS) {
systemPrompt = FIX_COMPILE_ERRORS_SYSTEM_PROMPT;
}
var systemPrompt = conversationType == FIX_COMPILE_ERRORS
? FIX_COMPILE_ERRORS_SYSTEM_PROMPT : ConfigurationSettings.getSystemPrompt();
var prompt = promptTemplate.buildPrompt(
systemPrompt,
@ -257,7 +260,7 @@ public class CompletionRequestProvider {
request.setModel(settings.getModel());
request.setMaxTokens(configuration.getMaxTokens());
request.setStream(true);
request.setSystem(COMPLETION_SYSTEM_PROMPT);
request.setSystem(ConfigurationSettings.getSystemPrompt());
List<ClaudeCompletionMessage> messages = conversation.getMessages().stream()
.filter(prevMessage -> prevMessage.getResponse() != null
&& !prevMessage.getResponse().isEmpty())
@ -284,14 +287,10 @@ public class CompletionRequestProvider {
private List<OpenAIChatCompletionMessage> buildMessages(CallParameters callParameters) {
var message = callParameters.getMessage();
var messages = new ArrayList<OpenAIChatCompletionMessage>();
if (callParameters.getConversationType() == ConversationType.DEFAULT) {
messages.add(new OpenAIChatCompletionStandardMessage(
"system",
ConfigurationSettings.getCurrentState().getSystemPrompt()));
}
if (callParameters.getConversationType() == ConversationType.FIX_COMPILE_ERRORS) {
messages.add(
new OpenAIChatCompletionStandardMessage("system", FIX_COMPILE_ERRORS_SYSTEM_PROMPT));
if (OPENAI_SYSTEM_CONVERSATION_TYPES.contains(callParameters.getConversationType())) {
String content = DEFAULT == callParameters.getConversationType()
? ConfigurationSettings.getSystemPrompt() : FIX_COMPILE_ERRORS_SYSTEM_PROMPT;
messages.add(new OpenAIChatCompletionStandardMessage("system", content));
}
for (var prevMessage : conversation.getMessages()) {

View file

@ -94,7 +94,7 @@ public class ConfigurationComponent {
maxTokensField.setValue(configuration.getMaxTokens());
systemPromptTextArea = new JTextArea();
if (configuration.getSystemPrompt().isEmpty()) {
if (configuration.getSystemPrompt().isBlank()) {
// for backward compatibility
systemPromptTextArea.setText(COMPLETION_SYSTEM_PROMPT);
} else {

View file

@ -31,4 +31,8 @@ public class ConfigurationSettings implements PersistentStateComponent<Configura
public static ConfigurationSettings getInstance() {
return ApplicationManager.getApplication().getService(ConfigurationSettings.class);
}
public static String getSystemPrompt() {
return getCurrentState().getSystemPrompt();
}
}

View file

@ -12,8 +12,7 @@ public class TotalTokensDetails {
private int referencedFilesTokens;
public TotalTokensDetails(EncodingManager encodingManager) {
systemPromptTokens = encodingManager.countTokens(
ConfigurationSettings.getCurrentState().getSystemPrompt());
systemPromptTokens = encodingManager.countTokens(ConfigurationSettings.getSystemPrompt());
}
public int getSystemPromptTokens() {

View file

@ -1,5 +1,6 @@
package ee.carlrobert.codegpt.completions
import ee.carlrobert.codegpt.completions.CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT
import ee.carlrobert.codegpt.conversations.ConversationService
import ee.carlrobert.codegpt.conversations.message.Message
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey
@ -42,6 +43,7 @@ class CompletionRequestProviderTest : IntegrationTest() {
}
fun testChatCompletionRequestWithoutSystemPromptOverride() {
ConfigurationSettings.getCurrentState().systemPrompt = COMPLETION_SYSTEM_PROMPT
val conversation = ConversationService.getInstance().startConversation()
val firstMessage = createDummyMessage(500)
val secondMessage = createDummyMessage(250)
@ -60,7 +62,7 @@ class CompletionRequestProviderTest : IntegrationTest() {
assertThat(request.messages)
.extracting("role", "content")
.containsExactly(
Tuple.tuple("system", CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT),
Tuple.tuple("system", COMPLETION_SYSTEM_PROMPT),
Tuple.tuple("user", "TEST_PROMPT"),
Tuple.tuple("assistant", firstMessage.response),
Tuple.tuple("user", "TEST_PROMPT"),
@ -69,7 +71,7 @@ class CompletionRequestProviderTest : IntegrationTest() {
}
fun testChatCompletionRequestRetry() {
ConfigurationSettings.getCurrentState().systemPrompt = CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT
ConfigurationSettings.getCurrentState().systemPrompt = "TEST_SYSTEM_PROMPT"
val conversation = ConversationService.getInstance().startConversation()
val firstMessage = createDummyMessage("FIRST_TEST_PROMPT", 500)
val secondMessage = createDummyMessage("SECOND_TEST_PROMPT", 250)
@ -88,13 +90,14 @@ class CompletionRequestProviderTest : IntegrationTest() {
assertThat(request.messages)
.extracting("role", "content")
.containsExactly(
Tuple.tuple("system", CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT),
Tuple.tuple("system", "TEST_SYSTEM_PROMPT"),
Tuple.tuple("user", "FIRST_TEST_PROMPT"),
Tuple.tuple("assistant", firstMessage.response),
Tuple.tuple("user", "SECOND_TEST_PROMPT"))
}
fun testReducedChatCompletionRequest() {
ConfigurationSettings.getCurrentState().systemPrompt = COMPLETION_SYSTEM_PROMPT
val conversation = ConversationService.getInstance().startConversation()
conversation.addMessage(createDummyMessage(50))
conversation.addMessage(createDummyMessage(100))
@ -116,7 +119,7 @@ class CompletionRequestProviderTest : IntegrationTest() {
assertThat(request.messages)
.extracting("role", "content")
.containsExactly(
Tuple.tuple("system", CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT),
Tuple.tuple("system", COMPLETION_SYSTEM_PROMPT),
Tuple.tuple("user", "TEST_PROMPT"),
Tuple.tuple("assistant", remainingMessage.response),
Tuple.tuple("user", "TEST_CHAT_COMPLETION_PROMPT"))

View file

@ -1,7 +1,6 @@
package ee.carlrobert.codegpt.completions
import ee.carlrobert.codegpt.CodeGPTPlugin
import ee.carlrobert.codegpt.completions.CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT
import ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA
import ee.carlrobert.codegpt.conversations.ConversationService
import ee.carlrobert.codegpt.conversations.message.Message
@ -20,6 +19,7 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() {
fun testOpenAIChatCompletionCall() {
useOpenAIService()
ConfigurationSettings.getCurrentState().systemPrompt = "TEST_SYSTEM_PROMPT"
val message = Message("TEST_PROMPT")
val conversation = ConversationService.getInstance().startConversation()
val requestHandler = CompletionRequestHandler(getRequestEventListener(message))
@ -34,7 +34,7 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() {
.containsExactly(
"gpt-4",
listOf(
mapOf("role" to "system", "content" to COMPLETION_SYSTEM_PROMPT),
mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
mapOf("role" to "user", "content" to "TEST_PROMPT")))
listOf(
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
@ -50,6 +50,7 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() {
fun testAzureChatCompletionCall() {
useAzureService()
ConfigurationSettings.getCurrentState().systemPrompt = "TEST_SYSTEM_PROMPT"
val conversationService = ConversationService.getInstance()
val prevMessage = Message("TEST_PREV_PROMPT")
prevMessage.response = "TEST_PREV_RESPONSE"
@ -66,7 +67,7 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() {
.extracting("messages")
.isEqualTo(
listOf(
mapOf("role" to "system", "content" to COMPLETION_SYSTEM_PROMPT),
mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
mapOf("role" to "user", "content" to "TEST_PREV_PROMPT"),
mapOf("role" to "assistant", "content" to "TEST_PREV_RESPONSE"),
mapOf("role" to "user", "content" to "TEST_PROMPT")))
@ -138,6 +139,7 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() {
fun testLlamaChatCompletionCall() {
useLlamaService()
ConfigurationSettings.getCurrentState().maxTokens = 99
ConfigurationSettings.getCurrentState().systemPrompt = "TEST_SYSTEM_PROMPT"
val message = Message("TEST_PROMPT")
val conversation = ConversationService.getInstance().startConversation()
conversation.addMessage(Message("Ping", "Pong"))
@ -151,7 +153,7 @@ class DefaultCompletionRequestHandlerTest : IntegrationTest() {
"stream")
.containsExactly(
LLAMA.buildPrompt(
COMPLETION_SYSTEM_PROMPT,
"TEST_SYSTEM_PROMPT",
"TEST_PROMPT",
conversation.messages),
99,

View file

@ -3,7 +3,6 @@ package ee.carlrobert.codegpt.toolwindow.chat
import ee.carlrobert.codegpt.CodeGPTKeys
import ee.carlrobert.codegpt.EncodingManager
import ee.carlrobert.codegpt.ReferencedFile
import ee.carlrobert.codegpt.completions.CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT
import ee.carlrobert.codegpt.completions.CompletionRequestProvider.FIX_COMPILE_ERRORS_SYSTEM_PROMPT
import ee.carlrobert.codegpt.completions.ConversationType
import ee.carlrobert.codegpt.completions.HuggingFaceModel
@ -31,7 +30,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
fun testSendingOpenAIMessage() {
useOpenAIService()
ConfigurationSettings.getCurrentState().systemPrompt = COMPLETION_SYSTEM_PROMPT
ConfigurationSettings.getCurrentState().systemPrompt = "TEST_SYSTEM_PROMPT"
val message = Message("Hello!")
val conversation = ConversationService.getInstance().startConversation()
val panel = ChatToolWindowTabPanel(project, conversation)
@ -46,7 +45,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
.containsExactly(
"gpt-4",
listOf(
mapOf("role" to "system", "content" to COMPLETION_SYSTEM_PROMPT),
mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
mapOf("role" to "user", "content" to "Hello!")))
listOf(
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
@ -68,7 +67,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
"userPromptTokens",
"highlightedTokens")
.containsExactly(
encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT),
encodingManager.countTokens("TEST_SYSTEM_PROMPT"),
encodingManager.countTokens(message.prompt),
0,
0)
@ -93,7 +92,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
ReferencedFile("TEST_FILE_NAME_2", "TEST_FILE_PATH_2", "TEST_FILE_CONTENT_2"),
ReferencedFile("TEST_FILE_NAME_3", "TEST_FILE_PATH_3", "TEST_FILE_CONTENT_3")))
useOpenAIService()
ConfigurationSettings.getCurrentState().systemPrompt = COMPLETION_SYSTEM_PROMPT
ConfigurationSettings.getCurrentState().systemPrompt = "TEST_SYSTEM_PROMPT"
val message = Message("TEST_MESSAGE")
message.userMessage = "TEST_MESSAGE"
message.referencedFilePaths = listOf("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3")
@ -110,7 +109,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
.containsExactly(
"gpt-4",
listOf(
mapOf("role" to "system", "content" to COMPLETION_SYSTEM_PROMPT),
mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
mapOf("role" to "user", "content" to """
Use the following context to answer question at the end:
@ -153,7 +152,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
"userPromptTokens",
"highlightedTokens")
.containsExactly(
encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT),
encodingManager.countTokens("TEST_SYSTEM_PROMPT"),
encodingManager.countTokens(message.prompt),
0,
0)
@ -180,7 +179,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
val testImagePath = Objects.requireNonNull(javaClass.getResource("/images/test-image.png")).path
project.putUserData(CodeGPTKeys.IMAGE_ATTACHMENT_FILE_PATH, testImagePath)
useOpenAIService("gpt-4-vision-preview")
ConfigurationSettings.getCurrentState().systemPrompt = COMPLETION_SYSTEM_PROMPT
ConfigurationSettings.getCurrentState().systemPrompt = "TEST_SYSTEM_PROMPT"
val message = Message("TEST_MESSAGE")
val conversation = ConversationService.getInstance().startConversation()
val panel = ChatToolWindowTabPanel(project, conversation)
@ -196,7 +195,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
.containsExactly(
"gpt-4-vision-preview",
listOf(
mapOf("role" to "system", "content" to COMPLETION_SYSTEM_PROMPT),
mapOf("role" to "system", "content" to "TEST_SYSTEM_PROMPT"),
mapOf("role" to "user", "content" to listOf(
mapOf(
"type" to "image_url",
@ -226,7 +225,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
"userPromptTokens",
"highlightedTokens")
.containsExactly(
encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT),
encodingManager.countTokens("TEST_SYSTEM_PROMPT"),
encodingManager.countTokens(message.prompt),
0,
0)
@ -256,7 +255,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
ReferencedFile("TEST_FILE_NAME_2", "TEST_FILE_PATH_2", "TEST_FILE_CONTENT_2"),
ReferencedFile("TEST_FILE_NAME_3", "TEST_FILE_PATH_3", "TEST_FILE_CONTENT_3")))
useOpenAIService()
ConfigurationSettings.getCurrentState().systemPrompt = COMPLETION_SYSTEM_PROMPT
ConfigurationSettings.getCurrentState().systemPrompt = "TEST_SYSTEM_PROMPT"
val message = Message("TEST_MESSAGE")
message.userMessage = "TEST_MESSAGE"
message.referencedFilePaths = listOf("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3")
@ -316,7 +315,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
"userPromptTokens",
"highlightedTokens")
.containsExactly(
encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT),
encodingManager.countTokens("TEST_SYSTEM_PROMPT"),
encodingManager.countTokens(message.prompt),
0,
0)
@ -342,7 +341,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
fun testSendingLlamaMessage() {
useLlamaService()
val configurationState = ConfigurationSettings.getCurrentState()
configurationState.systemPrompt = COMPLETION_SYSTEM_PROMPT
configurationState.systemPrompt = "TEST_SYSTEM_PROMPT"
configurationState.maxTokens = 1000
configurationState.temperature = 0.1
val llamaSettings = LlamaSettings.getCurrentState()
@ -369,7 +368,7 @@ class ChatToolWindowTabPanelTest : IntegrationTest() {
"repeat_penalty")
.containsExactly(
LLAMA.buildPrompt(
COMPLETION_SYSTEM_PROMPT,
"TEST_SYSTEM_PROMPT",
"TEST_PROMPT",
conversation.messages),
configurationState.maxTokens,