mirror of
https://github.com/carlrobertoh/ProxyAI.git
synced 2026-05-13 23:53:02 +00:00
chore: Convert Java tests to Kotlin (#447)
This commit is contained in:
parent
6fb0b8d30c
commit
0cdd5096ba
21 changed files with 1276 additions and 1271 deletions
|
|
@ -1,43 +0,0 @@
|
|||
package ee.carlrobert.codegpt.codecompletions;
|
||||
|
||||
import static ee.carlrobert.codegpt.CodeGPTKeys.PREVIOUS_INLAY_TEXT;
|
||||
import static ee.carlrobert.codegpt.codecompletions.InfillPromptTemplate.LLAMA;
|
||||
import static ee.carlrobert.codegpt.util.file.FileUtil.getResourceContent;
|
||||
import static ee.carlrobert.llm.client.util.JSONUtil.e;
|
||||
import static ee.carlrobert.llm.client.util.JSONUtil.jsonMapResponse;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
import com.intellij.openapi.editor.VisualPosition;
|
||||
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
|
||||
import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange;
|
||||
import java.util.List;
|
||||
import testsupport.IntegrationTest;
|
||||
|
||||
public class CodeCompletionServiceTest extends IntegrationTest {
|
||||
|
||||
private final VisualPosition cursorPosition = new VisualPosition(3, 0);
|
||||
|
||||
public void testFetchCodeCompletionLlama() {
|
||||
useLlamaService();
|
||||
LlamaSettings.getCurrentState().setCodeCompletionsEnabled(true);
|
||||
myFixture.configureByText(
|
||||
"CompletionTest.java",
|
||||
getResourceContent("/codecompletions/code-completion-file.txt"));
|
||||
myFixture.getEditor().getCaretModel().moveToVisualPosition(cursorPosition);
|
||||
var expectedCompletion = "TEST_OUTPUT";
|
||||
var prefix = "z".repeat(245) + "\n[INPUT]\nc"; // 128 tokens
|
||||
var suffix = "\n[\\INPUT]\n" + "z".repeat(247); // 128 tokens
|
||||
expectLlama((StreamHttpExchange) request -> {
|
||||
assertThat(request.getUri().getPath()).isEqualTo("/completion");
|
||||
assertThat(request.getMethod()).isEqualTo("POST");
|
||||
assertThat(request.getBody())
|
||||
.extracting("prompt")
|
||||
.isEqualTo(LLAMA.buildPrompt(prefix, suffix));
|
||||
return List.of(jsonMapResponse(e("content", expectedCompletion), e("stop", true)));
|
||||
});
|
||||
|
||||
myFixture.type('c');
|
||||
|
||||
waitExpecting(() -> "TEST_OUTPUT".equals(PREVIOUS_INLAY_TEXT.get(myFixture.getEditor())));
|
||||
}
|
||||
}
|
||||
|
|
@ -1,154 +0,0 @@
|
|||
package ee.carlrobert.codegpt.completions;
|
||||
|
||||
import static ee.carlrobert.codegpt.completions.CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.groups.Tuple.tuple;
|
||||
|
||||
import ee.carlrobert.codegpt.conversations.ConversationService;
|
||||
import ee.carlrobert.codegpt.conversations.message.Message;
|
||||
import ee.carlrobert.codegpt.credentials.CredentialsStore;
|
||||
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey;
|
||||
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings;
|
||||
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel;
|
||||
import testsupport.IntegrationTest;
|
||||
|
||||
public class CompletionRequestProviderTest extends IntegrationTest {
|
||||
|
||||
public void testChatCompletionRequestWithSystemPromptOverride() {
|
||||
CredentialsStore.INSTANCE.setCredential(CredentialKey.OPENAI_API_KEY, "TEST_API_KEY");
|
||||
ConfigurationSettings.getCurrentState().setSystemPrompt("TEST_SYSTEM_PROMPT");
|
||||
var conversation = ConversationService.getInstance().startConversation();
|
||||
var firstMessage = createDummyMessage(500);
|
||||
var secondMessage = createDummyMessage(250);
|
||||
conversation.addMessage(firstMessage);
|
||||
conversation.addMessage(secondMessage);
|
||||
|
||||
var request = new CompletionRequestProvider(conversation)
|
||||
.buildOpenAIChatCompletionRequest(
|
||||
OpenAIChatCompletionModel.GPT_3_5.getCode(),
|
||||
new CallParameters(
|
||||
conversation,
|
||||
ConversationType.DEFAULT,
|
||||
new Message("TEST_CHAT_COMPLETION_PROMPT"),
|
||||
false));
|
||||
|
||||
assertThat(request.getMessages())
|
||||
.extracting("role", "content")
|
||||
.containsExactly(
|
||||
tuple("system", "TEST_SYSTEM_PROMPT"),
|
||||
tuple("user", "TEST_PROMPT"),
|
||||
tuple("assistant", firstMessage.getResponse()),
|
||||
tuple("user", "TEST_PROMPT"),
|
||||
tuple("assistant", secondMessage.getResponse()),
|
||||
tuple("user", "TEST_CHAT_COMPLETION_PROMPT"));
|
||||
}
|
||||
|
||||
public void testChatCompletionRequestWithoutSystemPromptOverride() {
|
||||
var conversation = ConversationService.getInstance().startConversation();
|
||||
var firstMessage = createDummyMessage(500);
|
||||
var secondMessage = createDummyMessage(250);
|
||||
conversation.addMessage(firstMessage);
|
||||
conversation.addMessage(secondMessage);
|
||||
|
||||
var request = new CompletionRequestProvider(conversation)
|
||||
.buildOpenAIChatCompletionRequest(
|
||||
OpenAIChatCompletionModel.GPT_3_5.getCode(),
|
||||
new CallParameters(
|
||||
conversation,
|
||||
ConversationType.DEFAULT,
|
||||
new Message("TEST_CHAT_COMPLETION_PROMPT"),
|
||||
false));
|
||||
|
||||
assertThat(request.getMessages())
|
||||
.extracting("role", "content")
|
||||
.containsExactly(
|
||||
tuple("system", COMPLETION_SYSTEM_PROMPT),
|
||||
tuple("user", "TEST_PROMPT"),
|
||||
tuple("assistant", firstMessage.getResponse()),
|
||||
tuple("user", "TEST_PROMPT"),
|
||||
tuple("assistant", secondMessage.getResponse()),
|
||||
tuple("user", "TEST_CHAT_COMPLETION_PROMPT"));
|
||||
}
|
||||
|
||||
public void testChatCompletionRequestRetry() {
|
||||
ConfigurationSettings.getCurrentState().setSystemPrompt(COMPLETION_SYSTEM_PROMPT);
|
||||
var conversation = ConversationService.getInstance().startConversation();
|
||||
var firstMessage = createDummyMessage("FIRST_TEST_PROMPT", 500);
|
||||
var secondMessage = createDummyMessage("SECOND_TEST_PROMPT", 250);
|
||||
conversation.addMessage(firstMessage);
|
||||
conversation.addMessage(secondMessage);
|
||||
|
||||
var request = new CompletionRequestProvider(conversation)
|
||||
.buildOpenAIChatCompletionRequest(
|
||||
OpenAIChatCompletionModel.GPT_3_5.getCode(),
|
||||
new CallParameters(
|
||||
conversation,
|
||||
ConversationType.DEFAULT,
|
||||
secondMessage,
|
||||
true));
|
||||
|
||||
assertThat(request.getMessages())
|
||||
.extracting("role", "content")
|
||||
.containsExactly(
|
||||
tuple("system", COMPLETION_SYSTEM_PROMPT),
|
||||
tuple("user", "FIRST_TEST_PROMPT"),
|
||||
tuple("assistant", firstMessage.getResponse()),
|
||||
tuple("user", "SECOND_TEST_PROMPT"));
|
||||
}
|
||||
|
||||
public void testReducedChatCompletionRequest() {
|
||||
var conversation = ConversationService.getInstance().startConversation();
|
||||
conversation.addMessage(createDummyMessage(50));
|
||||
conversation.addMessage(createDummyMessage(100));
|
||||
conversation.addMessage(createDummyMessage(150));
|
||||
conversation.addMessage(createDummyMessage(1000));
|
||||
var remainingMessage = createDummyMessage(2000);
|
||||
conversation.addMessage(remainingMessage);
|
||||
conversation.discardTokenLimits();
|
||||
|
||||
var request = new CompletionRequestProvider(conversation)
|
||||
.buildOpenAIChatCompletionRequest(
|
||||
OpenAIChatCompletionModel.GPT_3_5.getCode(),
|
||||
new CallParameters(
|
||||
conversation,
|
||||
ConversationType.DEFAULT,
|
||||
new Message("TEST_CHAT_COMPLETION_PROMPT"),
|
||||
false));
|
||||
|
||||
assertThat(request.getMessages())
|
||||
.extracting("role", "content")
|
||||
.containsExactly(
|
||||
tuple("system", COMPLETION_SYSTEM_PROMPT),
|
||||
tuple("user", "TEST_PROMPT"),
|
||||
tuple("assistant", remainingMessage.getResponse()),
|
||||
tuple("user", "TEST_CHAT_COMPLETION_PROMPT"));
|
||||
}
|
||||
|
||||
public void testTotalUsageExceededException() {
|
||||
var conversation = ConversationService.getInstance().startConversation();
|
||||
conversation.addMessage(createDummyMessage(1500));
|
||||
conversation.addMessage(createDummyMessage(1500));
|
||||
conversation.addMessage(createDummyMessage(1500));
|
||||
|
||||
assertThrows(TotalUsageExceededException.class,
|
||||
() -> new CompletionRequestProvider(conversation)
|
||||
.buildOpenAIChatCompletionRequest(
|
||||
OpenAIChatCompletionModel.GPT_3_5.getCode(),
|
||||
new CallParameters(
|
||||
conversation,
|
||||
ConversationType.DEFAULT,
|
||||
createDummyMessage(100),
|
||||
false)));
|
||||
}
|
||||
|
||||
private Message createDummyMessage(int tokenSize) {
|
||||
return createDummyMessage("TEST_PROMPT", tokenSize);
|
||||
}
|
||||
|
||||
private Message createDummyMessage(String prompt, int tokenSize) {
|
||||
var message = new Message(prompt);
|
||||
// 'zz' = 1 token, prompt = 6 tokens, 7 tokens per message (GPT-3),
|
||||
message.setResponse("zz".repeat((tokenSize) - 6 - 7));
|
||||
return message;
|
||||
}
|
||||
}
|
||||
|
|
@ -1,183 +0,0 @@
|
|||
package ee.carlrobert.codegpt.completions;
|
||||
|
||||
import static ee.carlrobert.codegpt.completions.CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT;
|
||||
import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA;
|
||||
import static ee.carlrobert.llm.client.util.JSONUtil.e;
|
||||
import static ee.carlrobert.llm.client.util.JSONUtil.jsonArray;
|
||||
import static ee.carlrobert.llm.client.util.JSONUtil.jsonMap;
|
||||
import static ee.carlrobert.llm.client.util.JSONUtil.jsonMapResponse;
|
||||
import static org.apache.http.HttpHeaders.AUTHORIZATION;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
import ee.carlrobert.codegpt.CodeGPTPlugin;
|
||||
import ee.carlrobert.codegpt.conversations.ConversationService;
|
||||
import ee.carlrobert.codegpt.conversations.message.Message;
|
||||
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings;
|
||||
import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import testsupport.IntegrationTest;
|
||||
|
||||
public class DefaultCompletionRequestHandlerTest extends IntegrationTest {
|
||||
|
||||
public void testOpenAIChatCompletionCall() {
|
||||
useOpenAIService();
|
||||
var message = new Message("TEST_PROMPT");
|
||||
var conversation = ConversationService.getInstance().startConversation();
|
||||
var requestHandler = new CompletionRequestHandler(getRequestEventListener(message));
|
||||
expectOpenAI((StreamHttpExchange) request -> {
|
||||
assertThat(request.getUri().getPath()).isEqualTo("/v1/chat/completions");
|
||||
assertThat(request.getMethod()).isEqualTo("POST");
|
||||
assertThat(request.getHeaders().get(AUTHORIZATION).get(0)).isEqualTo("Bearer TEST_API_KEY");
|
||||
assertThat(request.getBody())
|
||||
.extracting(
|
||||
"model",
|
||||
"messages")
|
||||
.containsExactly(
|
||||
"gpt-4",
|
||||
List.of(
|
||||
Map.of("role", "system", "content", COMPLETION_SYSTEM_PROMPT),
|
||||
Map.of("role", "user", "content", "TEST_PROMPT")));
|
||||
|
||||
return List.of(
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))));
|
||||
});
|
||||
|
||||
requestHandler.call(new CallParameters(conversation, ConversationType.DEFAULT, message, false));
|
||||
|
||||
waitExpecting(() -> "Hello!".equals(message.getResponse()));
|
||||
}
|
||||
|
||||
public void testAzureChatCompletionCall() {
|
||||
useAzureService();
|
||||
var conversationService = ConversationService.getInstance();
|
||||
var prevMessage = new Message("TEST_PREV_PROMPT");
|
||||
prevMessage.setResponse("TEST_PREV_RESPONSE");
|
||||
var conversation = conversationService.startConversation();
|
||||
conversation.addMessage(prevMessage);
|
||||
conversationService.saveConversation(conversation);
|
||||
expectAzure((StreamHttpExchange) request -> {
|
||||
assertThat(request.getUri().getPath()).isEqualTo(
|
||||
"/openai/deployments/TEST_DEPLOYMENT_ID/chat/completions");
|
||||
assertThat(request.getUri().getQuery()).isEqualTo("api-version=TEST_API_VERSION");
|
||||
assertThat(request.getHeaders().get("Api-key").get(0)).isEqualTo("TEST_API_KEY");
|
||||
assertThat(request.getHeaders().get("X-llm-application-tag").get(0)).isEqualTo("codegpt");
|
||||
assertThat(request.getBody())
|
||||
.extracting("messages")
|
||||
.isEqualTo(
|
||||
List.of(
|
||||
Map.of("role", "system", "content", COMPLETION_SYSTEM_PROMPT),
|
||||
Map.of("role", "user", "content", "TEST_PREV_PROMPT"),
|
||||
Map.of("role", "assistant", "content", "TEST_PREV_RESPONSE"),
|
||||
Map.of("role", "user", "content", "TEST_PROMPT")));
|
||||
return List.of(
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))));
|
||||
});
|
||||
var message = new Message("TEST_PROMPT");
|
||||
var requestHandler = new CompletionRequestHandler(getRequestEventListener(message));
|
||||
|
||||
requestHandler.call(new CallParameters(conversation, ConversationType.DEFAULT, message, false));
|
||||
|
||||
waitExpecting(() -> "Hello!".equals(message.getResponse()));
|
||||
}
|
||||
|
||||
public void testYouChatCompletionCall() {
|
||||
useYouService();
|
||||
var message = new Message("TEST_PROMPT");
|
||||
var conversation = ConversationService.getInstance().startConversation();
|
||||
conversation.addMessage(new Message("Ping", "Pong"));
|
||||
var requestHandler = new CompletionRequestHandler(getRequestEventListener(message));
|
||||
expectYou((StreamHttpExchange) request -> {
|
||||
assertThat(request.getUri().getPath()).isEqualTo("/api/streamingSearch");
|
||||
assertThat(request.getMethod()).isEqualTo("GET");
|
||||
assertThat(request.getUri().getPath()).isEqualTo("/api/streamingSearch");
|
||||
assertThat(request.getUri().getQuery()).isEqualTo(
|
||||
"q=TEST_PROMPT&"
|
||||
+ "page=1&"
|
||||
+ "cfr=CodeGPT&"
|
||||
+ "count=10&"
|
||||
+ "safeSearch=WebPages,Translations,TimeZone,Computation,RelatedSearches&"
|
||||
+ "domain=youchat&"
|
||||
+ "selectedChatMode=default&"
|
||||
+ "chat=[{\"question\":\"Ping\",\"answer\":\"Pong\"}]&"
|
||||
+ "utm_source=ide&"
|
||||
+ "utm_medium=jetbrains&"
|
||||
+ "utm_campaign=" + CodeGPTPlugin.getVersion() + "&"
|
||||
+ "utm_content=CodeGPT");
|
||||
assertThat(request.getHeaders())
|
||||
.flatExtracting("Accept", "Connection", "User-agent", "Cookie")
|
||||
.containsExactly(
|
||||
"text/event-stream",
|
||||
"Keep-Alive",
|
||||
"youide CodeGPT",
|
||||
"safesearch_guest=Moderate; "
|
||||
+ "youpro_subscription=true; "
|
||||
+ "you_subscription=free; "
|
||||
+ "stytch_session=; "
|
||||
+ "ydc_stytch_session=; "
|
||||
+ "stytch_session_jwt=; "
|
||||
+ "ydc_stytch_session_jwt=; "
|
||||
+ "eg4=false; "
|
||||
+ "__cf_bm=aN2b3pQMH8XADeMB7bg9s1bJ_bfXBcCHophfOGRg6g0-1693601599-0-"
|
||||
+ "AWIt5Mr4Y3xQI4mIJ1lSf4+vijWKDobrty8OopDeBxY+NABe0MRFidF3dCUoWjRt8"
|
||||
+ "SVMvBZPI3zkOgcRs7Mz3yazd7f7c58HwW5Xg9jdBjNg;");
|
||||
return List.of(
|
||||
jsonMapResponse("youChatToken", "Hel"),
|
||||
jsonMapResponse("youChatToken", "lo"),
|
||||
jsonMapResponse("youChatToken", "!"));
|
||||
});
|
||||
|
||||
requestHandler.call(new CallParameters(conversation, ConversationType.DEFAULT, message, false));
|
||||
|
||||
waitExpecting(() -> "Hello!".equals(message.getResponse()));
|
||||
}
|
||||
|
||||
public void testLlamaChatCompletionCall() {
|
||||
useLlamaService();
|
||||
ConfigurationSettings.getCurrentState().setMaxTokens(99);
|
||||
var message = new Message("TEST_PROMPT");
|
||||
var conversation = ConversationService.getInstance().startConversation();
|
||||
conversation.addMessage(new Message("Ping", "Pong"));
|
||||
var requestHandler = new CompletionRequestHandler(getRequestEventListener(message));
|
||||
expectLlama((StreamHttpExchange) request -> {
|
||||
assertThat(request.getUri().getPath()).isEqualTo("/completion");
|
||||
assertThat(request.getBody())
|
||||
.extracting(
|
||||
"prompt",
|
||||
"n_predict",
|
||||
"stream")
|
||||
.containsExactly(
|
||||
LLAMA.buildPrompt(
|
||||
COMPLETION_SYSTEM_PROMPT,
|
||||
"TEST_PROMPT",
|
||||
conversation.getMessages()),
|
||||
99,
|
||||
true);
|
||||
return List.of(
|
||||
jsonMapResponse("content", "Hel"),
|
||||
jsonMapResponse("content", "lo!"),
|
||||
jsonMapResponse(
|
||||
e("content", ""),
|
||||
e("stop", true)));
|
||||
});
|
||||
|
||||
requestHandler.call(new CallParameters(conversation, ConversationType.DEFAULT, message, false));
|
||||
|
||||
waitExpecting(() -> "Hello!".equals(message.getResponse()));
|
||||
}
|
||||
|
||||
private CompletionResponseEventListener getRequestEventListener(Message message) {
|
||||
return new CompletionResponseEventListener() {
|
||||
@Override
|
||||
public void handleCompleted(String fullMessage, CallParameters callParameters) {
|
||||
message.setResponse(fullMessage);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
@ -1,145 +0,0 @@
|
|||
package ee.carlrobert.codegpt.completions;
|
||||
|
||||
import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.ALPACA;
|
||||
import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.CHAT_ML;
|
||||
import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA;
|
||||
import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.TORA;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
import ee.carlrobert.codegpt.conversations.message.Message;
|
||||
import java.util.List;
|
||||
import org.junit.Test;
|
||||
|
||||
public class PromptTemplateTest {
|
||||
|
||||
private static final String SYSTEM_PROMPT = "TEST_SYSTEM_PROMPT";
|
||||
private static final String USER_PROMPT = "TEST_USER_PROMPT";
|
||||
private static final List<Message> HISTORY = List.of(
|
||||
new Message("TEST_PREV_PROMPT_1", "TEST_PREV_RESPONSE_1"),
|
||||
new Message("TEST_PREV_PROMPT_2", "TEST_PREV_RESPONSE_2"));
|
||||
|
||||
@Test
|
||||
public void shouldBuildLlamaPromptWithHistory() {
|
||||
var prompt = LLAMA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY);
|
||||
|
||||
assertThat(prompt).isEqualTo("""
|
||||
<<SYS>>TEST_SYSTEM_PROMPT<</SYS>>
|
||||
[INST]TEST_PREV_PROMPT_1[/INST]
|
||||
TEST_PREV_RESPONSE_1
|
||||
[INST]TEST_PREV_PROMPT_2[/INST]
|
||||
TEST_PREV_RESPONSE_2
|
||||
[INST]TEST_USER_PROMPT[/INST]""");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldBuildLlamaPromptWithoutHistory() {
|
||||
var prompt = LLAMA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, List.of());
|
||||
|
||||
assertThat(prompt).isEqualTo("""
|
||||
<<SYS>>TEST_SYSTEM_PROMPT<</SYS>>
|
||||
[INST]TEST_USER_PROMPT[/INST]""");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldBuildAlpacaPromptWithHistory() {
|
||||
var prompt = ALPACA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY);
|
||||
|
||||
assertThat(prompt).isEqualTo("""
|
||||
Below is an instruction that describes a task. \
|
||||
Write a response that appropriately completes the request.
|
||||
|
||||
### Instruction
|
||||
TEST_PREV_PROMPT_1
|
||||
|
||||
### Response:
|
||||
TEST_PREV_RESPONSE_1
|
||||
|
||||
### Instruction
|
||||
TEST_PREV_PROMPT_2
|
||||
|
||||
### Response:
|
||||
TEST_PREV_RESPONSE_2
|
||||
|
||||
### Instruction
|
||||
TEST_USER_PROMPT
|
||||
|
||||
### Response:
|
||||
""");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldBuildAlpacaPromptWithoutHistory() {
|
||||
var prompt = ALPACA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, List.of());
|
||||
|
||||
assertThat(prompt).isEqualTo("""
|
||||
Below is an instruction that describes a task. \
|
||||
Write a response that appropriately completes the request.
|
||||
|
||||
### Instruction
|
||||
TEST_USER_PROMPT
|
||||
|
||||
### Response:
|
||||
""");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldBuildChatMLPromptWithHistory() {
|
||||
var prompt = CHAT_ML.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY);
|
||||
|
||||
assertThat(prompt).isEqualTo("""
|
||||
<|im_start|>system
|
||||
TEST_SYSTEM_PROMPT<|im_end|>
|
||||
<|im_start|>user
|
||||
TEST_PREV_PROMPT_1<|im_end|>
|
||||
<|im_start|>assistant
|
||||
TEST_PREV_RESPONSE_1<|im_end|>
|
||||
<|im_start|>user
|
||||
TEST_PREV_PROMPT_2<|im_end|>
|
||||
<|im_start|>assistant
|
||||
TEST_PREV_RESPONSE_2<|im_end|>
|
||||
<|im_start|>user
|
||||
TEST_USER_PROMPT<|im_end|>"""
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldBuildChatMLPromptWithoutHistory() {
|
||||
var prompt = CHAT_ML.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, List.of());
|
||||
|
||||
assertThat(prompt).isEqualTo("""
|
||||
<|im_start|>system
|
||||
TEST_SYSTEM_PROMPT<|im_end|>
|
||||
<|im_start|>user
|
||||
TEST_USER_PROMPT<|im_end|>""");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldBuildToRAPromptWithHistory() {
|
||||
var prompt = TORA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY);
|
||||
|
||||
assertThat(prompt).isEqualTo("""
|
||||
<|user|>
|
||||
TEST_PREV_PROMPT_1
|
||||
<|assistant|>
|
||||
TEST_PREV_RESPONSE_1
|
||||
<|user|>
|
||||
TEST_PREV_PROMPT_2
|
||||
<|assistant|>
|
||||
TEST_PREV_RESPONSE_2
|
||||
<|user|>
|
||||
TEST_USER_PROMPT
|
||||
<|assistant|>"""
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldBuildToRAPromptWithoutHistory() {
|
||||
var prompt = TORA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, List.of());
|
||||
|
||||
assertThat(prompt).isEqualTo("""
|
||||
<|user|>
|
||||
TEST_USER_PROMPT
|
||||
<|assistant|>"""
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,90 +0,0 @@
|
|||
package ee.carlrobert.codegpt.conversations;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
import com.intellij.testFramework.fixtures.BasePlatformTestCase;
|
||||
import ee.carlrobert.codegpt.conversations.message.Message;
|
||||
import ee.carlrobert.codegpt.settings.GeneralSettings;
|
||||
import ee.carlrobert.codegpt.settings.service.ServiceType;
|
||||
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
|
||||
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel;
|
||||
|
||||
public class ConversationsStateTest extends BasePlatformTestCase {
|
||||
|
||||
public void testStartNewDefaultConversation() {
|
||||
GeneralSettings.getCurrentState().setSelectedService(ServiceType.OPENAI);
|
||||
OpenAISettings.getCurrentState().setModel(OpenAIChatCompletionModel.GPT_3_5.getCode());
|
||||
|
||||
var conversation = ConversationService.getInstance().startConversation();
|
||||
|
||||
assertThat(conversation).isEqualTo(ConversationsState.getCurrentConversation());
|
||||
assertThat(conversation)
|
||||
.extracting("clientCode", "model")
|
||||
.containsExactly("chat.completion", "gpt-3.5-turbo");
|
||||
}
|
||||
|
||||
public void testSaveConversation() {
|
||||
var service = ConversationService.getInstance();
|
||||
var conversation = service.createConversation("chat.completion");
|
||||
service.addConversation(conversation);
|
||||
var message = new Message("TEST_PROMPT");
|
||||
message.setResponse("TEST_RESPONSE");
|
||||
conversation.addMessage(message);
|
||||
|
||||
service.saveConversation(conversation);
|
||||
|
||||
var currentConversation = ConversationsState.getCurrentConversation();
|
||||
assertThat(currentConversation).isNotNull();
|
||||
assertThat(currentConversation.getMessages())
|
||||
.flatExtracting("prompt", "response")
|
||||
.containsExactly("TEST_PROMPT", "TEST_RESPONSE");
|
||||
}
|
||||
|
||||
public void testGetPreviousConversation() {
|
||||
var service = ConversationService.getInstance();
|
||||
var firstConversation = service.startConversation();
|
||||
service.startConversation();
|
||||
|
||||
var previousConversation = service.getPreviousConversation();
|
||||
|
||||
assertThat(previousConversation.isPresent()).isTrue();
|
||||
assertThat(previousConversation.get()).isEqualTo(firstConversation);
|
||||
}
|
||||
|
||||
public void testGetNextConversation() {
|
||||
var service = ConversationService.getInstance();
|
||||
var firstConversation = service.startConversation();
|
||||
var secondConversation = service.startConversation();
|
||||
ConversationsState.getInstance().setCurrentConversation(firstConversation);
|
||||
|
||||
var nextConversation = service.getNextConversation();
|
||||
|
||||
assertThat(nextConversation.isPresent()).isTrue();
|
||||
assertThat(nextConversation.get()).isEqualTo(secondConversation);
|
||||
}
|
||||
|
||||
public void testDeleteSelectedConversation() {
|
||||
var service = ConversationService.getInstance();
|
||||
var firstConversation = service.startConversation();
|
||||
service.startConversation();
|
||||
|
||||
service.deleteSelectedConversation();
|
||||
|
||||
assertThat(ConversationsState.getCurrentConversation()).isEqualTo(firstConversation);
|
||||
assertThat(service.getSortedConversations().size()).isEqualTo(1);
|
||||
assertThat(service.getSortedConversations())
|
||||
.extracting("id")
|
||||
.containsExactly(firstConversation.getId());
|
||||
}
|
||||
|
||||
public void testClearAllConversations() {
|
||||
var service = ConversationService.getInstance();
|
||||
service.startConversation();
|
||||
service.startConversation();
|
||||
|
||||
service.clearAll();
|
||||
|
||||
assertThat(ConversationsState.getCurrentConversation()).isNull();
|
||||
assertThat(service.getSortedConversations().size()).isEqualTo(0);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,81 +0,0 @@
|
|||
package ee.carlrobert.codegpt.settings.state;
|
||||
|
||||
import static ee.carlrobert.codegpt.completions.HuggingFaceModel.CODE_LLAMA_7B_Q3;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
import com.intellij.testFramework.fixtures.BasePlatformTestCase;
|
||||
import ee.carlrobert.codegpt.completions.HuggingFaceModel;
|
||||
import ee.carlrobert.codegpt.conversations.Conversation;
|
||||
import ee.carlrobert.codegpt.settings.GeneralSettings;
|
||||
import ee.carlrobert.codegpt.settings.service.ServiceType;
|
||||
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
|
||||
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
|
||||
|
||||
public class GeneralSettingsTest extends BasePlatformTestCase {
|
||||
|
||||
public void testOpenAISettingsSync() {
|
||||
var openAISettings = OpenAISettings.getCurrentState();
|
||||
openAISettings.setModel("gpt-3.5-turbo");
|
||||
var conversation = new Conversation();
|
||||
conversation.setModel("gpt-4");
|
||||
conversation.setClientCode("chat.completion");
|
||||
var settings = GeneralSettings.getInstance();
|
||||
|
||||
settings.sync(conversation);
|
||||
|
||||
assertThat(settings.getState().getSelectedService()).isEqualTo(ServiceType.OPENAI);
|
||||
assertThat(openAISettings.getModel()).isEqualTo("gpt-4");
|
||||
}
|
||||
|
||||
public void testAzureSettingsSync() {
|
||||
var settings = GeneralSettings.getInstance();
|
||||
var conversation = new Conversation();
|
||||
conversation.setModel("gpt-4");
|
||||
conversation.setClientCode("azure.chat.completion");
|
||||
|
||||
settings.sync(conversation);
|
||||
|
||||
assertThat(settings.getState().getSelectedService()).isEqualTo(ServiceType.AZURE);
|
||||
}
|
||||
|
||||
public void testYouSettingsSync() {
|
||||
var settings = GeneralSettings.getInstance();
|
||||
var conversation = new Conversation();
|
||||
conversation.setModel("YouCode");
|
||||
conversation.setClientCode("you.chat.completion");
|
||||
|
||||
settings.sync(conversation);
|
||||
|
||||
assertThat(settings.getState().getSelectedService()).isEqualTo(ServiceType.YOU);
|
||||
}
|
||||
|
||||
public void testLlamaSettingsModelPathSync() {
|
||||
var llamaSettings = LlamaSettings.getCurrentState();
|
||||
llamaSettings.setHuggingFaceModel(HuggingFaceModel.WIZARD_CODER_PYTHON_7B_Q3);
|
||||
var conversation = new Conversation();
|
||||
conversation.setModel("TEST_LLAMA_MODEL_PATH");
|
||||
conversation.setClientCode("llama.chat.completion");
|
||||
var settings = GeneralSettings.getInstance();
|
||||
|
||||
settings.sync(conversation);
|
||||
|
||||
assertThat(settings.getState().getSelectedService()).isEqualTo(ServiceType.LLAMA_CPP);
|
||||
assertThat(llamaSettings.getCustomLlamaModelPath()).isEqualTo("TEST_LLAMA_MODEL_PATH");
|
||||
assertThat(llamaSettings.isUseCustomModel()).isTrue();
|
||||
}
|
||||
|
||||
public void testLlamaSettingsHuggingFaceModelSync() {
|
||||
var llamaSettings = LlamaSettings.getCurrentState();
|
||||
llamaSettings.setHuggingFaceModel(HuggingFaceModel.WIZARD_CODER_PYTHON_7B_Q3);
|
||||
var conversation = new Conversation();
|
||||
conversation.setModel("CODE_LLAMA_7B_Q3");
|
||||
conversation.setClientCode("llama.chat.completion");
|
||||
var settings = GeneralSettings.getInstance();
|
||||
|
||||
settings.sync(conversation);
|
||||
|
||||
assertThat(settings.getState().getSelectedService()).isEqualTo(ServiceType.LLAMA_CPP);
|
||||
assertThat(llamaSettings.getHuggingFaceModel()).isEqualTo(CODE_LLAMA_7B_Q3);
|
||||
assertThat(llamaSettings.isUseCustomModel()).isFalse();
|
||||
}
|
||||
}
|
||||
|
|
@ -1,415 +0,0 @@
|
|||
package ee.carlrobert.codegpt.toolwindow.chat;
|
||||
|
||||
import static ee.carlrobert.codegpt.completions.CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT;
|
||||
import static ee.carlrobert.codegpt.completions.CompletionRequestProvider.FIX_COMPILE_ERRORS_SYSTEM_PROMPT;
|
||||
import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA;
|
||||
import static ee.carlrobert.llm.client.util.JSONUtil.e;
|
||||
import static ee.carlrobert.llm.client.util.JSONUtil.jsonArray;
|
||||
import static ee.carlrobert.llm.client.util.JSONUtil.jsonMap;
|
||||
import static ee.carlrobert.llm.client.util.JSONUtil.jsonMapResponse;
|
||||
import static java.util.Objects.requireNonNull;
|
||||
import static org.apache.http.HttpHeaders.AUTHORIZATION;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
import ee.carlrobert.codegpt.CodeGPTKeys;
|
||||
import ee.carlrobert.codegpt.EncodingManager;
|
||||
import ee.carlrobert.codegpt.ReferencedFile;
|
||||
import ee.carlrobert.codegpt.completions.ConversationType;
|
||||
import ee.carlrobert.codegpt.completions.HuggingFaceModel;
|
||||
import ee.carlrobert.codegpt.conversations.ConversationService;
|
||||
import ee.carlrobert.codegpt.conversations.message.Message;
|
||||
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings;
|
||||
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
|
||||
import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.Base64;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import testsupport.IntegrationTest;
|
||||
|
||||
public class ChatToolWindowTabPanelTest extends IntegrationTest {
|
||||
|
||||
public void testSendingOpenAIMessage() {
|
||||
useOpenAIService();
|
||||
ConfigurationSettings.getCurrentState().setSystemPrompt(COMPLETION_SYSTEM_PROMPT);
|
||||
var message = new Message("Hello!");
|
||||
var conversation = ConversationService.getInstance().startConversation();
|
||||
var panel = new ChatToolWindowTabPanel(getProject(), conversation);
|
||||
expectOpenAI((StreamHttpExchange) request -> {
|
||||
assertThat(request.getUri().getPath()).isEqualTo("/v1/chat/completions");
|
||||
assertThat(request.getMethod()).isEqualTo("POST");
|
||||
assertThat(request.getHeaders().get(AUTHORIZATION).get(0)).isEqualTo("Bearer TEST_API_KEY");
|
||||
assertThat(request.getBody())
|
||||
.extracting(
|
||||
"model",
|
||||
"messages")
|
||||
.containsExactly(
|
||||
"gpt-4",
|
||||
List.of(
|
||||
Map.of("role", "system", "content", COMPLETION_SYSTEM_PROMPT),
|
||||
Map.of("role", "user", "content", "Hello!")));
|
||||
return List.of(
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))));
|
||||
});
|
||||
|
||||
panel.sendMessage(message);
|
||||
|
||||
waitExpecting(() -> {
|
||||
var messages = conversation.getMessages();
|
||||
return !messages.isEmpty() && "Hello!".equals(messages.get(0).getResponse());
|
||||
});
|
||||
var encodingManager = EncodingManager.getInstance();
|
||||
assertThat(panel.getTokenDetails()).extracting(
|
||||
"systemPromptTokens",
|
||||
"conversationTokens",
|
||||
"userPromptTokens",
|
||||
"highlightedTokens")
|
||||
.containsExactly(
|
||||
encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT),
|
||||
encodingManager.countTokens(message.getPrompt()),
|
||||
0,
|
||||
0);
|
||||
assertThat(panel.getConversation())
|
||||
.isNotNull()
|
||||
.extracting("id", "model", "clientCode", "discardTokenLimit")
|
||||
.containsExactly(
|
||||
conversation.getId(),
|
||||
conversation.getModel(),
|
||||
conversation.getClientCode(),
|
||||
false);
|
||||
var messages = panel.getConversation().getMessages();
|
||||
assertThat(messages.size()).isOne();
|
||||
assertThat(messages.get(0))
|
||||
.extracting("id", "prompt", "response")
|
||||
.containsExactly(message.getId(), message.getPrompt(), message.getResponse());
|
||||
}
|
||||
|
||||
public void testSendingOpenAIMessageWithReferencedContext() {
|
||||
getProject().putUserData(CodeGPTKeys.SELECTED_FILES, List.of(
|
||||
new ReferencedFile("TEST_FILE_NAME_1", "TEST_FILE_PATH_1", "TEST_FILE_CONTENT_1"),
|
||||
new ReferencedFile("TEST_FILE_NAME_2", "TEST_FILE_PATH_2", "TEST_FILE_CONTENT_2"),
|
||||
new ReferencedFile("TEST_FILE_NAME_3", "TEST_FILE_PATH_3", "TEST_FILE_CONTENT_3")));
|
||||
useOpenAIService();
|
||||
ConfigurationSettings.getCurrentState().setSystemPrompt(COMPLETION_SYSTEM_PROMPT);
|
||||
var message = new Message("TEST_MESSAGE");
|
||||
message.setUserMessage("TEST_MESSAGE");
|
||||
message.setReferencedFilePaths(
|
||||
List.of("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3"));
|
||||
var conversation = ConversationService.getInstance().startConversation();
|
||||
var panel = new ChatToolWindowTabPanel(getProject(), conversation);
|
||||
expectOpenAI((StreamHttpExchange) request -> {
|
||||
assertThat(request.getUri().getPath()).isEqualTo("/v1/chat/completions");
|
||||
assertThat(request.getMethod()).isEqualTo("POST");
|
||||
assertThat(request.getHeaders().get(AUTHORIZATION).get(0)).isEqualTo("Bearer TEST_API_KEY");
|
||||
assertThat(request.getBody())
|
||||
.extracting(
|
||||
"model",
|
||||
"messages")
|
||||
.containsExactly(
|
||||
"gpt-4",
|
||||
List.of(
|
||||
Map.of("role", "system", "content", COMPLETION_SYSTEM_PROMPT),
|
||||
Map.of("role", "user", "content",
|
||||
"""
|
||||
Use the following context to answer question at the end:
|
||||
|
||||
File Path: TEST_FILE_PATH_1
|
||||
File Content:
|
||||
```TEST_FILE_NAME_1
|
||||
TEST_FILE_CONTENT_1
|
||||
```
|
||||
|
||||
File Path: TEST_FILE_PATH_2
|
||||
File Content:
|
||||
```TEST_FILE_NAME_2
|
||||
TEST_FILE_CONTENT_2
|
||||
```
|
||||
|
||||
File Path: TEST_FILE_PATH_3
|
||||
File Content:
|
||||
```TEST_FILE_NAME_3
|
||||
TEST_FILE_CONTENT_3
|
||||
```
|
||||
|
||||
Question: TEST_MESSAGE""")));
|
||||
return List.of(
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))));
|
||||
});
|
||||
|
||||
panel.sendMessage(message);
|
||||
|
||||
waitExpecting(() -> {
|
||||
var messages = conversation.getMessages();
|
||||
return !messages.isEmpty() && "Hello!".equals(messages.get(0).getResponse());
|
||||
});
|
||||
var encodingManager = EncodingManager.getInstance();
|
||||
assertThat(panel.getTokenDetails()).extracting(
|
||||
"systemPromptTokens",
|
||||
"conversationTokens",
|
||||
"userPromptTokens",
|
||||
"highlightedTokens")
|
||||
.containsExactly(
|
||||
encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT),
|
||||
encodingManager.countTokens(message.getPrompt()),
|
||||
0,
|
||||
0);
|
||||
assertThat(panel.getConversation())
|
||||
.isNotNull()
|
||||
.extracting("id", "model", "clientCode", "discardTokenLimit")
|
||||
.containsExactly(
|
||||
conversation.getId(),
|
||||
conversation.getModel(),
|
||||
conversation.getClientCode(),
|
||||
false);
|
||||
var messages = panel.getConversation().getMessages();
|
||||
assertThat(messages.size()).isOne();
|
||||
assertThat(messages.get(0))
|
||||
.extracting("id", "prompt", "response", "referencedFilePaths")
|
||||
.containsExactly(
|
||||
message.getId(),
|
||||
message.getPrompt(),
|
||||
message.getResponse(),
|
||||
List.of("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3"));
|
||||
}
|
||||
|
||||
public void testSendingOpenAIMessageWithImage() {
|
||||
var testImagePath = requireNonNull(getClass().getResource("/images/test-image.png")).getPath();
|
||||
getProject().putUserData(CodeGPTKeys.IMAGE_ATTACHMENT_FILE_PATH, testImagePath);
|
||||
useOpenAIService("gpt-4-vision-preview");
|
||||
ConfigurationSettings.getCurrentState().setSystemPrompt(COMPLETION_SYSTEM_PROMPT);
|
||||
var message = new Message("TEST_MESSAGE");
|
||||
var conversation = ConversationService.getInstance().startConversation();
|
||||
var panel = new ChatToolWindowTabPanel(getProject(), conversation);
|
||||
expectOpenAI((StreamHttpExchange) request -> {
|
||||
assertThat(request.getUri().getPath()).isEqualTo("/v1/chat/completions");
|
||||
assertThat(request.getMethod()).isEqualTo("POST");
|
||||
assertThat(request.getHeaders().get(AUTHORIZATION).get(0)).isEqualTo("Bearer TEST_API_KEY");
|
||||
try {
|
||||
var testImageUrl = "data:image/png;base64,"
|
||||
+ Base64.getEncoder().encodeToString(Files.readAllBytes(Path.of(testImagePath)));
|
||||
assertThat(request.getBody())
|
||||
.extracting("model", "messages")
|
||||
.containsExactly(
|
||||
"gpt-4-vision-preview",
|
||||
List.of(
|
||||
Map.of("role", "system", "content", COMPLETION_SYSTEM_PROMPT),
|
||||
Map.of("role", "user", "content", List.of(
|
||||
Map.of(
|
||||
"type", "image_url",
|
||||
"image_url", Map.of("url", testImageUrl)),
|
||||
Map.of("type", "text", "text", "TEST_MESSAGE")
|
||||
))));
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return List.of(
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))));
|
||||
});
|
||||
|
||||
panel.sendMessage(message);
|
||||
|
||||
waitExpecting(() -> {
|
||||
var messages = conversation.getMessages();
|
||||
return !messages.isEmpty() && "Hello!".equals(messages.get(0).getResponse());
|
||||
});
|
||||
var encodingManager = EncodingManager.getInstance();
|
||||
assertThat(panel.getTokenDetails()).extracting(
|
||||
"systemPromptTokens",
|
||||
"conversationTokens",
|
||||
"userPromptTokens",
|
||||
"highlightedTokens")
|
||||
.containsExactly(
|
||||
encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT),
|
||||
encodingManager.countTokens(message.getPrompt()),
|
||||
0,
|
||||
0);
|
||||
assertThat(panel.getConversation())
|
||||
.isNotNull()
|
||||
.extracting("id", "model", "clientCode", "discardTokenLimit")
|
||||
.containsExactly(
|
||||
conversation.getId(),
|
||||
conversation.getModel(),
|
||||
conversation.getClientCode(),
|
||||
false);
|
||||
var messages = panel.getConversation().getMessages();
|
||||
assertThat(messages.size()).isOne();
|
||||
assertThat(messages.get(0))
|
||||
.extracting("id", "prompt", "response", "imageFilePath")
|
||||
.containsExactly(
|
||||
message.getId(),
|
||||
message.getPrompt(),
|
||||
message.getResponse(),
|
||||
message.getImageFilePath());
|
||||
}
|
||||
|
||||
public void testFixCompileErrorsWithOpenAIService() {
|
||||
getProject().putUserData(CodeGPTKeys.SELECTED_FILES, List.of(
|
||||
new ReferencedFile("TEST_FILE_NAME_1", "TEST_FILE_PATH_1", "TEST_FILE_CONTENT_1"),
|
||||
new ReferencedFile("TEST_FILE_NAME_2", "TEST_FILE_PATH_2", "TEST_FILE_CONTENT_2"),
|
||||
new ReferencedFile("TEST_FILE_NAME_3", "TEST_FILE_PATH_3", "TEST_FILE_CONTENT_3")));
|
||||
useOpenAIService();
|
||||
ConfigurationSettings.getCurrentState().setSystemPrompt(COMPLETION_SYSTEM_PROMPT);
|
||||
var message = new Message("TEST_MESSAGE");
|
||||
message.setUserMessage("TEST_MESSAGE");
|
||||
message.setReferencedFilePaths(
|
||||
List.of("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3"));
|
||||
var conversation = ConversationService.getInstance().startConversation();
|
||||
var panel = new ChatToolWindowTabPanel(getProject(), conversation);
|
||||
expectOpenAI((StreamHttpExchange) request -> {
|
||||
assertThat(request.getUri().getPath()).isEqualTo("/v1/chat/completions");
|
||||
assertThat(request.getMethod()).isEqualTo("POST");
|
||||
assertThat(request.getHeaders().get(AUTHORIZATION).get(0)).isEqualTo("Bearer TEST_API_KEY");
|
||||
assertThat(request.getBody())
|
||||
.extracting(
|
||||
"model",
|
||||
"messages")
|
||||
.containsExactly(
|
||||
"gpt-4",
|
||||
List.of(
|
||||
Map.of("role", "system", "content", FIX_COMPILE_ERRORS_SYSTEM_PROMPT),
|
||||
Map.of("role", "user", "content",
|
||||
"""
|
||||
Use the following context to answer question at the end:
|
||||
|
||||
File Path: TEST_FILE_PATH_1
|
||||
File Content:
|
||||
```TEST_FILE_NAME_1
|
||||
TEST_FILE_CONTENT_1
|
||||
```
|
||||
|
||||
File Path: TEST_FILE_PATH_2
|
||||
File Content:
|
||||
```TEST_FILE_NAME_2
|
||||
TEST_FILE_CONTENT_2
|
||||
```
|
||||
|
||||
File Path: TEST_FILE_PATH_3
|
||||
File Content:
|
||||
```TEST_FILE_NAME_3
|
||||
TEST_FILE_CONTENT_3
|
||||
```
|
||||
|
||||
Question: TEST_MESSAGE""")));
|
||||
return List.of(
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))));
|
||||
});
|
||||
|
||||
panel.sendMessage(message, ConversationType.FIX_COMPILE_ERRORS);
|
||||
|
||||
waitExpecting(() -> {
|
||||
var messages = conversation.getMessages();
|
||||
return !messages.isEmpty() && "Hello!".equals(messages.get(0).getResponse());
|
||||
});
|
||||
var encodingManager = EncodingManager.getInstance();
|
||||
assertThat(panel.getTokenDetails()).extracting(
|
||||
"systemPromptTokens",
|
||||
"conversationTokens",
|
||||
"userPromptTokens",
|
||||
"highlightedTokens")
|
||||
.containsExactly(
|
||||
encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT),
|
||||
encodingManager.countTokens(message.getPrompt()),
|
||||
0,
|
||||
0);
|
||||
assertThat(panel.getConversation())
|
||||
.isNotNull()
|
||||
.extracting("id", "model", "clientCode", "discardTokenLimit")
|
||||
.containsExactly(
|
||||
conversation.getId(),
|
||||
conversation.getModel(),
|
||||
conversation.getClientCode(),
|
||||
false);
|
||||
var messages = panel.getConversation().getMessages();
|
||||
assertThat(messages.size()).isOne();
|
||||
assertThat(messages.get(0))
|
||||
.extracting("id", "prompt", "response", "referencedFilePaths")
|
||||
.containsExactly(
|
||||
message.getId(),
|
||||
message.getPrompt(),
|
||||
message.getResponse(),
|
||||
List.of("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3"));
|
||||
}
|
||||
|
||||
public void testSendingLlamaMessage() {
|
||||
useLlamaService();
|
||||
var configurationState = ConfigurationSettings.getCurrentState();
|
||||
configurationState.setSystemPrompt(COMPLETION_SYSTEM_PROMPT);
|
||||
configurationState.setMaxTokens(1000);
|
||||
configurationState.setTemperature(0.1);
|
||||
var llamaSettings = LlamaSettings.getCurrentState();
|
||||
llamaSettings.setUseCustomModel(false);
|
||||
llamaSettings.setHuggingFaceModel(HuggingFaceModel.CODE_LLAMA_7B_Q4);
|
||||
llamaSettings.setTopK(30);
|
||||
llamaSettings.setTopP(0.8);
|
||||
llamaSettings.setMinP(0.03);
|
||||
llamaSettings.setRepeatPenalty(1.3);
|
||||
var message = new Message("TEST_PROMPT");
|
||||
var conversation = ConversationService.getInstance().startConversation();
|
||||
var panel = new ChatToolWindowTabPanel(getProject(), conversation);
|
||||
expectLlama((StreamHttpExchange) request -> {
|
||||
assertThat(request.getUri().getPath()).isEqualTo("/completion");
|
||||
assertThat(request.getBody())
|
||||
.extracting(
|
||||
"prompt",
|
||||
"n_predict",
|
||||
"stream",
|
||||
"temperature",
|
||||
"top_k",
|
||||
"top_p",
|
||||
"min_p",
|
||||
"repeat_penalty")
|
||||
.containsExactly(
|
||||
LLAMA.buildPrompt(
|
||||
COMPLETION_SYSTEM_PROMPT,
|
||||
"TEST_PROMPT",
|
||||
conversation.getMessages()),
|
||||
configurationState.getMaxTokens(),
|
||||
true,
|
||||
configurationState.getTemperature(),
|
||||
llamaSettings.getTopK(),
|
||||
llamaSettings.getTopP(),
|
||||
llamaSettings.getMinP(),
|
||||
llamaSettings.getRepeatPenalty());
|
||||
return List.of(
|
||||
jsonMapResponse("content", "Hel"),
|
||||
jsonMapResponse("content", "lo!"),
|
||||
jsonMapResponse(
|
||||
e("content", ""),
|
||||
e("stop", true)));
|
||||
});
|
||||
|
||||
panel.sendMessage(message, ConversationType.DEFAULT);
|
||||
|
||||
waitExpecting(() -> {
|
||||
var messages = conversation.getMessages();
|
||||
return !messages.isEmpty() && "Hello!".equals(messages.get(0).getResponse());
|
||||
});
|
||||
assertThat(panel.getConversation())
|
||||
.isNotNull()
|
||||
.extracting("id", "model", "clientCode", "discardTokenLimit")
|
||||
.containsExactly(
|
||||
conversation.getId(),
|
||||
conversation.getModel(),
|
||||
conversation.getClientCode(),
|
||||
false);
|
||||
var messages = panel.getConversation().getMessages();
|
||||
assertThat(messages.size()).isOne();
|
||||
assertThat(messages.get(0))
|
||||
.extracting("id", "prompt", "response")
|
||||
.containsExactly(message.getId(), message.getPrompt(), message.getResponse());
|
||||
}
|
||||
}
|
||||
|
|
@ -1,50 +0,0 @@
|
|||
package ee.carlrobert.codegpt.toolwindow.chat;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
import com.intellij.openapi.util.Disposer;
|
||||
import com.intellij.testFramework.fixtures.BasePlatformTestCase;
|
||||
import ee.carlrobert.codegpt.conversations.ConversationService;
|
||||
import ee.carlrobert.codegpt.conversations.message.Message;
|
||||
|
||||
public class ChatToolWindowTabbedPaneTest extends BasePlatformTestCase {
|
||||
|
||||
public void testClearAllTabs() {
|
||||
var tabbedPane = new ChatToolWindowTabbedPane(Disposer.newDisposable());
|
||||
tabbedPane.addNewTab(createNewTabPanel());
|
||||
|
||||
tabbedPane.clearAll();
|
||||
|
||||
assertThat(tabbedPane.getActiveTabMapping()).isEmpty();
|
||||
}
|
||||
|
||||
|
||||
public void testAddingNewTabs() {
|
||||
var tabbedPane = new ChatToolWindowTabbedPane(Disposer.newDisposable());
|
||||
|
||||
tabbedPane.addNewTab(createNewTabPanel());
|
||||
tabbedPane.addNewTab(createNewTabPanel());
|
||||
tabbedPane.addNewTab(createNewTabPanel());
|
||||
|
||||
assertThat(tabbedPane.getActiveTabMapping().keySet())
|
||||
.containsExactly("Chat 1", "Chat 2", "Chat 3");
|
||||
}
|
||||
|
||||
public void testResetCurrentlyActiveTabPanel() {
|
||||
var tabbedPane = new ChatToolWindowTabbedPane(Disposer.newDisposable());
|
||||
var conversation = ConversationService.getInstance().startConversation();
|
||||
conversation.addMessage(new Message("TEST_PROMPT", "TEST_RESPONSE"));
|
||||
tabbedPane.addNewTab(new ChatToolWindowTabPanel(getProject(), conversation));
|
||||
|
||||
tabbedPane.resetCurrentlyActiveTabPanel(getProject());
|
||||
|
||||
var tabPanel = tabbedPane.getActiveTabMapping().get("Chat 1");
|
||||
assertThat(tabPanel.getConversation().getMessages()).isEmpty();
|
||||
}
|
||||
|
||||
private ChatToolWindowTabPanel createNewTabPanel() {
|
||||
return new ChatToolWindowTabPanel(
|
||||
getProject(),
|
||||
ConversationService.getInstance().startConversation());
|
||||
}
|
||||
}
|
||||
|
|
@ -1,34 +0,0 @@
|
|||
package testsupport;
|
||||
|
||||
import com.intellij.openapi.util.Key;
|
||||
import com.intellij.testFramework.fixtures.BasePlatformTestCase;
|
||||
import ee.carlrobert.codegpt.CodeGPTKeys;
|
||||
import ee.carlrobert.llm.client.mixin.ExternalServiceTestMixin;
|
||||
import java.util.Collections;
|
||||
import testsupport.mixin.ShortcutsTestMixin;
|
||||
|
||||
public class IntegrationTest extends BasePlatformTestCase implements
|
||||
ExternalServiceTestMixin,
|
||||
ShortcutsTestMixin {
|
||||
|
||||
static {
|
||||
ExternalServiceTestMixin.init();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void tearDown() throws Exception {
|
||||
ExternalServiceTestMixin.clearAll();
|
||||
clearKeys();
|
||||
super.tearDown();
|
||||
}
|
||||
|
||||
private void clearKeys() {
|
||||
putUserData(CodeGPTKeys.SELECTED_FILES, Collections.emptyList());
|
||||
putUserData(CodeGPTKeys.PREVIOUS_INLAY_TEXT, "");
|
||||
putUserData(CodeGPTKeys.IMAGE_ATTACHMENT_FILE_PATH, "");
|
||||
}
|
||||
|
||||
private <T> void putUserData(Key<T> key, T value) {
|
||||
getProject().putUserData(key, value);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,51 +0,0 @@
|
|||
package testsupport.mixin;
|
||||
|
||||
import static ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.AZURE_OPENAI_API_KEY;
|
||||
import static ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.OPENAI_API_KEY;
|
||||
|
||||
import com.intellij.testFramework.PlatformTestUtil;
|
||||
import ee.carlrobert.codegpt.credentials.CredentialsStore;
|
||||
import ee.carlrobert.codegpt.settings.GeneralSettings;
|
||||
import ee.carlrobert.codegpt.settings.service.ServiceType;
|
||||
import ee.carlrobert.codegpt.settings.service.azure.AzureSettings;
|
||||
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
|
||||
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
|
||||
import java.util.function.BooleanSupplier;
|
||||
|
||||
public interface ShortcutsTestMixin {
|
||||
|
||||
default void useOpenAIService() {
|
||||
useOpenAIService("gpt-4");
|
||||
}
|
||||
|
||||
default void useOpenAIService(String model) {
|
||||
GeneralSettings.getCurrentState().setSelectedService(ServiceType.OPENAI);
|
||||
CredentialsStore.INSTANCE.setCredential(OPENAI_API_KEY, "TEST_API_KEY");
|
||||
OpenAISettings.getCurrentState().setModel(model);
|
||||
}
|
||||
|
||||
default void useAzureService() {
|
||||
GeneralSettings.getCurrentState().setSelectedService(ServiceType.AZURE);
|
||||
CredentialsStore.INSTANCE.setCredential(AZURE_OPENAI_API_KEY, "TEST_API_KEY");
|
||||
var azureSettings = AzureSettings.getCurrentState();
|
||||
azureSettings.setResourceName("TEST_RESOURCE_NAME");
|
||||
azureSettings.setApiVersion("TEST_API_VERSION");
|
||||
azureSettings.setDeploymentId("TEST_DEPLOYMENT_ID");
|
||||
}
|
||||
|
||||
default void useYouService() {
|
||||
GeneralSettings.getCurrentState().setSelectedService(ServiceType.YOU);
|
||||
}
|
||||
|
||||
default void useLlamaService() {
|
||||
GeneralSettings.getCurrentState().setSelectedService(ServiceType.LLAMA_CPP);
|
||||
LlamaSettings.getCurrentState().setServerPort(null);
|
||||
}
|
||||
|
||||
default void waitExpecting(BooleanSupplier condition) {
|
||||
PlatformTestUtil.waitWithEventsDispatching(
|
||||
"Waiting for message response timed out or did not meet expected conditions",
|
||||
condition,
|
||||
5);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
package ee.carlrobert.codegpt.codecompletions
|
||||
|
||||
import com.intellij.openapi.editor.VisualPosition
|
||||
import ee.carlrobert.codegpt.CodeGPTKeys.PREVIOUS_INLAY_TEXT
|
||||
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings
|
||||
import ee.carlrobert.codegpt.util.file.FileUtil.getResourceContent
|
||||
import ee.carlrobert.llm.client.http.RequestEntity
|
||||
import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange
|
||||
import ee.carlrobert.llm.client.util.JSONUtil.e
|
||||
import ee.carlrobert.llm.client.util.JSONUtil.jsonMapResponse
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import testsupport.IntegrationTest
|
||||
|
||||
class CodeCompletionServiceTest : IntegrationTest() {
|
||||
|
||||
private val cursorPosition = VisualPosition(3, 0)
|
||||
|
||||
fun testFetchCodeCompletionLlama() {
|
||||
useLlamaService()
|
||||
LlamaSettings.getCurrentState().isCodeCompletionsEnabled = true
|
||||
myFixture.configureByText(
|
||||
"CompletionTest.java",
|
||||
getResourceContent("/codecompletions/code-completion-file.txt")
|
||||
)
|
||||
myFixture.editor.caretModel.moveToVisualPosition(cursorPosition)
|
||||
val expectedCompletion = "TEST_OUTPUT"
|
||||
val prefix = """
|
||||
${"z".repeat(245)}
|
||||
[INPUT]
|
||||
c
|
||||
""".trimIndent() // 128 tokens
|
||||
val suffix = """
|
||||
|
||||
[\INPUT]
|
||||
${"z".repeat(247)}
|
||||
""".trimIndent() // 128 tokens
|
||||
expectLlama(StreamHttpExchange { request: RequestEntity ->
|
||||
assertThat(request.uri.path).isEqualTo("/completion")
|
||||
assertThat(request.method).isEqualTo("POST")
|
||||
assertThat(request.body)
|
||||
.extracting("prompt")
|
||||
.isEqualTo(InfillPromptTemplate.LLAMA.buildPrompt(prefix, suffix))
|
||||
listOf(jsonMapResponse(e("content", expectedCompletion), e("stop", true)))
|
||||
})
|
||||
|
||||
myFixture.type('c')
|
||||
|
||||
waitExpecting { "TEST_OUTPUT" == PREVIOUS_INLAY_TEXT[myFixture.editor] }
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,152 @@
|
|||
package ee.carlrobert.codegpt.completions
|
||||
|
||||
import ee.carlrobert.codegpt.conversations.ConversationService
|
||||
import ee.carlrobert.codegpt.conversations.message.Message
|
||||
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey
|
||||
import ee.carlrobert.codegpt.credentials.CredentialsStore.setCredential
|
||||
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
|
||||
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.assertj.core.groups.Tuple
|
||||
import testsupport.IntegrationTest
|
||||
|
||||
class CompletionRequestProviderTest : IntegrationTest() {
|
||||
|
||||
fun testChatCompletionRequestWithSystemPromptOverride() {
|
||||
setCredential(CredentialKey.OPENAI_API_KEY, "TEST_API_KEY")
|
||||
ConfigurationSettings.getCurrentState().systemPrompt = "TEST_SYSTEM_PROMPT"
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
val firstMessage = createDummyMessage(500)
|
||||
val secondMessage = createDummyMessage(250)
|
||||
conversation.addMessage(firstMessage)
|
||||
conversation.addMessage(secondMessage)
|
||||
|
||||
val request = CompletionRequestProvider(conversation)
|
||||
.buildOpenAIChatCompletionRequest(
|
||||
OpenAIChatCompletionModel.GPT_3_5.code,
|
||||
CallParameters(
|
||||
conversation,
|
||||
ConversationType.DEFAULT,
|
||||
Message("TEST_CHAT_COMPLETION_PROMPT"),
|
||||
false))
|
||||
|
||||
assertThat(request.messages)
|
||||
.extracting("role", "content")
|
||||
.containsExactly(
|
||||
Tuple.tuple("system", "TEST_SYSTEM_PROMPT"),
|
||||
Tuple.tuple("user", "TEST_PROMPT"),
|
||||
Tuple.tuple("assistant", firstMessage.response),
|
||||
Tuple.tuple("user", "TEST_PROMPT"),
|
||||
Tuple.tuple("assistant", secondMessage.response),
|
||||
Tuple.tuple("user", "TEST_CHAT_COMPLETION_PROMPT"))
|
||||
}
|
||||
|
||||
fun testChatCompletionRequestWithoutSystemPromptOverride() {
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
val firstMessage = createDummyMessage(500)
|
||||
val secondMessage = createDummyMessage(250)
|
||||
conversation.addMessage(firstMessage)
|
||||
conversation.addMessage(secondMessage)
|
||||
|
||||
val request = CompletionRequestProvider(conversation)
|
||||
.buildOpenAIChatCompletionRequest(
|
||||
OpenAIChatCompletionModel.GPT_3_5.code,
|
||||
CallParameters(
|
||||
conversation,
|
||||
ConversationType.DEFAULT,
|
||||
Message("TEST_CHAT_COMPLETION_PROMPT"),
|
||||
false))
|
||||
|
||||
assertThat(request.messages)
|
||||
.extracting("role", "content")
|
||||
.containsExactly(
|
||||
Tuple.tuple("system", CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT),
|
||||
Tuple.tuple("user", "TEST_PROMPT"),
|
||||
Tuple.tuple("assistant", firstMessage.response),
|
||||
Tuple.tuple("user", "TEST_PROMPT"),
|
||||
Tuple.tuple("assistant", secondMessage.response),
|
||||
Tuple.tuple("user", "TEST_CHAT_COMPLETION_PROMPT"))
|
||||
}
|
||||
|
||||
fun testChatCompletionRequestRetry() {
|
||||
ConfigurationSettings.getCurrentState().systemPrompt = CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
val firstMessage = createDummyMessage("FIRST_TEST_PROMPT", 500)
|
||||
val secondMessage = createDummyMessage("SECOND_TEST_PROMPT", 250)
|
||||
conversation.addMessage(firstMessage)
|
||||
conversation.addMessage(secondMessage)
|
||||
|
||||
val request = CompletionRequestProvider(conversation)
|
||||
.buildOpenAIChatCompletionRequest(
|
||||
OpenAIChatCompletionModel.GPT_3_5.code,
|
||||
CallParameters(
|
||||
conversation,
|
||||
ConversationType.DEFAULT,
|
||||
secondMessage,
|
||||
true))
|
||||
|
||||
assertThat(request.messages)
|
||||
.extracting("role", "content")
|
||||
.containsExactly(
|
||||
Tuple.tuple("system", CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT),
|
||||
Tuple.tuple("user", "FIRST_TEST_PROMPT"),
|
||||
Tuple.tuple("assistant", firstMessage.response),
|
||||
Tuple.tuple("user", "SECOND_TEST_PROMPT"))
|
||||
}
|
||||
|
||||
fun testReducedChatCompletionRequest() {
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
conversation.addMessage(createDummyMessage(50))
|
||||
conversation.addMessage(createDummyMessage(100))
|
||||
conversation.addMessage(createDummyMessage(150))
|
||||
conversation.addMessage(createDummyMessage(1000))
|
||||
val remainingMessage = createDummyMessage(2000)
|
||||
conversation.addMessage(remainingMessage)
|
||||
conversation.discardTokenLimits()
|
||||
|
||||
val request = CompletionRequestProvider(conversation)
|
||||
.buildOpenAIChatCompletionRequest(
|
||||
OpenAIChatCompletionModel.GPT_3_5.code,
|
||||
CallParameters(
|
||||
conversation,
|
||||
ConversationType.DEFAULT,
|
||||
Message("TEST_CHAT_COMPLETION_PROMPT"),
|
||||
false))
|
||||
|
||||
assertThat(request.messages)
|
||||
.extracting("role", "content")
|
||||
.containsExactly(
|
||||
Tuple.tuple("system", CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT),
|
||||
Tuple.tuple("user", "TEST_PROMPT"),
|
||||
Tuple.tuple("assistant", remainingMessage.response),
|
||||
Tuple.tuple("user", "TEST_CHAT_COMPLETION_PROMPT"))
|
||||
}
|
||||
|
||||
fun testTotalUsageExceededException() {
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
conversation.addMessage(createDummyMessage(1500))
|
||||
conversation.addMessage(createDummyMessage(1500))
|
||||
conversation.addMessage(createDummyMessage(1500))
|
||||
|
||||
assertThrows(TotalUsageExceededException::class.java) {
|
||||
CompletionRequestProvider(conversation)
|
||||
.buildOpenAIChatCompletionRequest(
|
||||
OpenAIChatCompletionModel.GPT_3_5.code,
|
||||
CallParameters(
|
||||
conversation,
|
||||
ConversationType.DEFAULT,
|
||||
createDummyMessage(100),
|
||||
false)) }
|
||||
}
|
||||
|
||||
private fun createDummyMessage(tokenSize: Int): Message {
|
||||
return createDummyMessage("TEST_PROMPT", tokenSize)
|
||||
}
|
||||
|
||||
private fun createDummyMessage(prompt: String, tokenSize: Int): Message {
|
||||
val message = Message(prompt)
|
||||
// 'zz' = 1 token, prompt = 6 tokens, 7 tokens per message (GPT-3),
|
||||
message.response = "zz".repeat((tokenSize) - 6 - 7)
|
||||
return message
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,179 @@
|
|||
package ee.carlrobert.codegpt.completions
|
||||
|
||||
import ee.carlrobert.codegpt.CodeGPTPlugin
|
||||
import ee.carlrobert.codegpt.completions.CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT
|
||||
import ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA
|
||||
import ee.carlrobert.codegpt.conversations.ConversationService
|
||||
import ee.carlrobert.codegpt.conversations.message.Message
|
||||
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
|
||||
import ee.carlrobert.llm.client.http.RequestEntity
|
||||
import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange
|
||||
import ee.carlrobert.llm.client.util.JSONUtil.e
|
||||
import ee.carlrobert.llm.client.util.JSONUtil.jsonArray
|
||||
import ee.carlrobert.llm.client.util.JSONUtil.jsonMap
|
||||
import ee.carlrobert.llm.client.util.JSONUtil.jsonMapResponse
|
||||
import org.apache.http.HttpHeaders
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import testsupport.IntegrationTest
|
||||
|
||||
class DefaultCompletionRequestHandlerTest : IntegrationTest() {
|
||||
|
||||
fun testOpenAIChatCompletionCall() {
|
||||
useOpenAIService()
|
||||
val message = Message("TEST_PROMPT")
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
val requestHandler = CompletionRequestHandler(getRequestEventListener(message))
|
||||
expectOpenAI(StreamHttpExchange { request: RequestEntity ->
|
||||
assertThat(request.uri.path).isEqualTo("/v1/chat/completions")
|
||||
assertThat(request.method).isEqualTo("POST")
|
||||
assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY")
|
||||
assertThat(request.body)
|
||||
.extracting(
|
||||
"model",
|
||||
"messages")
|
||||
.containsExactly(
|
||||
"gpt-4",
|
||||
listOf(
|
||||
mapOf("role" to "system", "content" to COMPLETION_SYSTEM_PROMPT),
|
||||
mapOf("role" to "user", "content" to "TEST_PROMPT")))
|
||||
listOf(
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))))
|
||||
})
|
||||
|
||||
requestHandler.call(CallParameters(conversation, ConversationType.DEFAULT, message, false))
|
||||
|
||||
waitExpecting { "Hello!" == message.response }
|
||||
}
|
||||
|
||||
fun testAzureChatCompletionCall() {
|
||||
useAzureService()
|
||||
val conversationService = ConversationService.getInstance()
|
||||
val prevMessage = Message("TEST_PREV_PROMPT")
|
||||
prevMessage.response = "TEST_PREV_RESPONSE"
|
||||
val conversation = conversationService.startConversation()
|
||||
conversation.addMessage(prevMessage)
|
||||
conversationService.saveConversation(conversation)
|
||||
expectAzure(StreamHttpExchange { request: RequestEntity ->
|
||||
assertThat(request.uri.path).isEqualTo(
|
||||
"/openai/deployments/TEST_DEPLOYMENT_ID/chat/completions")
|
||||
assertThat(request.uri.query).isEqualTo("api-version=TEST_API_VERSION")
|
||||
assertThat(request.headers["Api-key"]!![0]).isEqualTo("TEST_API_KEY")
|
||||
assertThat(request.headers["X-llm-application-tag"]!![0]).isEqualTo("codegpt")
|
||||
assertThat(request.body)
|
||||
.extracting("messages")
|
||||
.isEqualTo(
|
||||
listOf(
|
||||
mapOf("role" to "system", "content" to COMPLETION_SYSTEM_PROMPT),
|
||||
mapOf("role" to "user", "content" to "TEST_PREV_PROMPT"),
|
||||
mapOf("role" to "assistant", "content" to "TEST_PREV_RESPONSE"),
|
||||
mapOf("role" to "user", "content" to "TEST_PROMPT")))
|
||||
listOf(
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))))
|
||||
})
|
||||
val message = Message("TEST_PROMPT")
|
||||
val requestHandler = CompletionRequestHandler(getRequestEventListener(message))
|
||||
|
||||
requestHandler.call(CallParameters(conversation, ConversationType.DEFAULT, message, false))
|
||||
|
||||
waitExpecting { "Hello!" == message.response }
|
||||
}
|
||||
|
||||
fun testYouChatCompletionCall() {
|
||||
useYouService()
|
||||
val message = Message("TEST_PROMPT")
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
conversation.addMessage(Message("Ping", "Pong"))
|
||||
val requestHandler = CompletionRequestHandler(getRequestEventListener(message))
|
||||
expectYou(StreamHttpExchange { request: RequestEntity ->
|
||||
assertThat(request.uri.path).isEqualTo("/api/streamingSearch")
|
||||
assertThat(request.method).isEqualTo("GET")
|
||||
assertThat(request.uri.path).isEqualTo("/api/streamingSearch")
|
||||
assertThat(request.uri.query).isEqualTo(
|
||||
"q=TEST_PROMPT&"
|
||||
+ "page=1&"
|
||||
+ "cfr=CodeGPT&"
|
||||
+ "count=10&"
|
||||
+ "safeSearch=WebPages,Translations,TimeZone,Computation,RelatedSearches&"
|
||||
+ "domain=youchat&"
|
||||
+ "selectedChatMode=default&"
|
||||
+ "chat=[{\"question\":\"Ping\",\"answer\":\"Pong\"}]&"
|
||||
+ "utm_source=ide&"
|
||||
+ "utm_medium=jetbrains&"
|
||||
+ "utm_campaign=" + CodeGPTPlugin.getVersion() + "&"
|
||||
+ "utm_content=CodeGPT")
|
||||
assertThat(request.headers)
|
||||
.flatExtracting("Accept", "Connection", "User-agent", "Cookie")
|
||||
.containsExactly(
|
||||
"text/event-stream",
|
||||
"Keep-Alive",
|
||||
"youide CodeGPT",
|
||||
"safesearch_guest=Moderate; "
|
||||
+ "youpro_subscription=true; "
|
||||
+ "you_subscription=free; "
|
||||
+ "stytch_session=; "
|
||||
+ "ydc_stytch_session=; "
|
||||
+ "stytch_session_jwt=; "
|
||||
+ "ydc_stytch_session_jwt=; "
|
||||
+ "eg4=false; "
|
||||
+ "__cf_bm=aN2b3pQMH8XADeMB7bg9s1bJ_bfXBcCHophfOGRg6g0-1693601599-0-"
|
||||
+ "AWIt5Mr4Y3xQI4mIJ1lSf4+vijWKDobrty8OopDeBxY+NABe0MRFidF3dCUoWjRt8"
|
||||
+ "SVMvBZPI3zkOgcRs7Mz3yazd7f7c58HwW5Xg9jdBjNg;")
|
||||
listOf(
|
||||
jsonMapResponse("youChatToken", "Hel"),
|
||||
jsonMapResponse("youChatToken", "lo"),
|
||||
jsonMapResponse("youChatToken", "!"))
|
||||
})
|
||||
|
||||
requestHandler.call(CallParameters(conversation, ConversationType.DEFAULT, message, false))
|
||||
|
||||
waitExpecting { "Hello!" == message.response }
|
||||
}
|
||||
|
||||
fun testLlamaChatCompletionCall() {
|
||||
useLlamaService()
|
||||
ConfigurationSettings.getCurrentState().maxTokens = 99
|
||||
val message = Message("TEST_PROMPT")
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
conversation.addMessage(Message("Ping", "Pong"))
|
||||
val requestHandler = CompletionRequestHandler(getRequestEventListener(message))
|
||||
expectLlama(StreamHttpExchange { request: RequestEntity ->
|
||||
assertThat(request.uri.path).isEqualTo("/completion")
|
||||
assertThat(request.body)
|
||||
.extracting(
|
||||
"prompt",
|
||||
"n_predict",
|
||||
"stream")
|
||||
.containsExactly(
|
||||
LLAMA.buildPrompt(
|
||||
COMPLETION_SYSTEM_PROMPT,
|
||||
"TEST_PROMPT",
|
||||
conversation.messages),
|
||||
99,
|
||||
true)
|
||||
listOf<String?>(
|
||||
jsonMapResponse("content", "Hel"),
|
||||
jsonMapResponse("content", "lo!"),
|
||||
jsonMapResponse(
|
||||
e("content", ""),
|
||||
e("stop", true)))
|
||||
})
|
||||
|
||||
requestHandler.call(CallParameters(conversation, ConversationType.DEFAULT, message, false))
|
||||
|
||||
waitExpecting { "Hello!" == message.response }
|
||||
}
|
||||
|
||||
private fun getRequestEventListener(message: Message): CompletionResponseEventListener {
|
||||
return object : CompletionResponseEventListener {
|
||||
override fun handleCompleted(fullMessage: String, callParameters: CallParameters) {
|
||||
message.response = fullMessage
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,149 @@
|
|||
package ee.carlrobert.codegpt.completions
|
||||
|
||||
import ee.carlrobert.codegpt.completions.llama.PromptTemplate.ALPACA
|
||||
import ee.carlrobert.codegpt.completions.llama.PromptTemplate.CHAT_ML
|
||||
import ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA
|
||||
import ee.carlrobert.codegpt.completions.llama.PromptTemplate.TORA
|
||||
import ee.carlrobert.codegpt.conversations.message.Message
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.junit.Test
|
||||
|
||||
class PromptTemplateTest {
|
||||
|
||||
@Test
|
||||
fun shouldBuildLlamaPromptWithHistory() {
|
||||
val prompt = LLAMA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY)
|
||||
|
||||
assertThat(prompt).isEqualTo("""
|
||||
<<SYS>>TEST_SYSTEM_PROMPT<</SYS>>
|
||||
[INST]TEST_PREV_PROMPT_1[/INST]
|
||||
TEST_PREV_RESPONSE_1
|
||||
[INST]TEST_PREV_PROMPT_2[/INST]
|
||||
TEST_PREV_RESPONSE_2
|
||||
[INST]TEST_USER_PROMPT[/INST]
|
||||
""".trimIndent())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun shouldBuildLlamaPromptWithoutHistory() {
|
||||
val prompt = LLAMA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, listOf())
|
||||
|
||||
assertThat(prompt).isEqualTo("""
|
||||
<<SYS>>TEST_SYSTEM_PROMPT<</SYS>>
|
||||
[INST]TEST_USER_PROMPT[/INST]
|
||||
""".trimIndent())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun shouldBuildAlpacaPromptWithHistory() {
|
||||
val prompt = ALPACA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY)
|
||||
|
||||
assertThat(prompt).isEqualTo("""
|
||||
Below is an instruction that describes a task. Write a response that appropriately completes the request.
|
||||
|
||||
### Instruction
|
||||
TEST_PREV_PROMPT_1
|
||||
|
||||
### Response:
|
||||
TEST_PREV_RESPONSE_1
|
||||
|
||||
### Instruction
|
||||
TEST_PREV_PROMPT_2
|
||||
|
||||
### Response:
|
||||
TEST_PREV_RESPONSE_2
|
||||
|
||||
### Instruction
|
||||
TEST_USER_PROMPT
|
||||
|
||||
### Response:
|
||||
|
||||
""".trimIndent())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun shouldBuildAlpacaPromptWithoutHistory() {
|
||||
val prompt = ALPACA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, listOf())
|
||||
|
||||
assertThat(prompt).isEqualTo("""
|
||||
Below is an instruction that describes a task. Write a response that appropriately completes the request.
|
||||
|
||||
### Instruction
|
||||
TEST_USER_PROMPT
|
||||
|
||||
### Response:
|
||||
|
||||
""".trimIndent())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun shouldBuildChatMLPromptWithHistory() {
|
||||
val prompt = CHAT_ML.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY)
|
||||
|
||||
assertThat(prompt).isEqualTo("""
|
||||
<|im_start|>system
|
||||
TEST_SYSTEM_PROMPT<|im_end|>
|
||||
<|im_start|>user
|
||||
TEST_PREV_PROMPT_1<|im_end|>
|
||||
<|im_start|>assistant
|
||||
TEST_PREV_RESPONSE_1<|im_end|>
|
||||
<|im_start|>user
|
||||
TEST_PREV_PROMPT_2<|im_end|>
|
||||
<|im_start|>assistant
|
||||
TEST_PREV_RESPONSE_2<|im_end|>
|
||||
<|im_start|>user
|
||||
TEST_USER_PROMPT<|im_end|>
|
||||
""".trimIndent())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun shouldBuildChatMLPromptWithoutHistory() {
|
||||
val prompt = CHAT_ML.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, listOf())
|
||||
|
||||
assertThat(prompt).isEqualTo("""
|
||||
<|im_start|>system
|
||||
TEST_SYSTEM_PROMPT<|im_end|>
|
||||
<|im_start|>user
|
||||
TEST_USER_PROMPT<|im_end|>
|
||||
""".trimIndent())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun shouldBuildToRAPromptWithHistory() {
|
||||
val prompt = TORA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY)
|
||||
|
||||
assertThat(prompt).isEqualTo("""
|
||||
<|user|>
|
||||
TEST_PREV_PROMPT_1
|
||||
<|assistant|>
|
||||
TEST_PREV_RESPONSE_1
|
||||
<|user|>
|
||||
TEST_PREV_PROMPT_2
|
||||
<|assistant|>
|
||||
TEST_PREV_RESPONSE_2
|
||||
<|user|>
|
||||
TEST_USER_PROMPT
|
||||
<|assistant|>
|
||||
""".trimIndent())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun shouldBuildToRAPromptWithoutHistory() {
|
||||
val prompt = TORA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, listOf())
|
||||
|
||||
assertThat(prompt).isEqualTo("""
|
||||
<|user|>
|
||||
TEST_USER_PROMPT
|
||||
<|assistant|>
|
||||
""".trimIndent())
|
||||
}
|
||||
|
||||
companion object {
|
||||
private const val SYSTEM_PROMPT = "TEST_SYSTEM_PROMPT"
|
||||
private const val USER_PROMPT = "TEST_USER_PROMPT"
|
||||
private val HISTORY: List<Message> = listOf(
|
||||
Message("TEST_PREV_PROMPT_1", "TEST_PREV_RESPONSE_1"),
|
||||
Message("TEST_PREV_PROMPT_2", "TEST_PREV_RESPONSE_2")
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,89 @@
|
|||
package ee.carlrobert.codegpt.conversations
|
||||
|
||||
import com.intellij.testFramework.fixtures.BasePlatformTestCase
|
||||
import ee.carlrobert.codegpt.conversations.message.Message
|
||||
import ee.carlrobert.codegpt.settings.GeneralSettings
|
||||
import ee.carlrobert.codegpt.settings.service.ServiceType
|
||||
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings
|
||||
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
|
||||
class ConversationsStateTest : BasePlatformTestCase() {
|
||||
|
||||
fun testStartNewDefaultConversation() {
|
||||
GeneralSettings.getCurrentState().selectedService = ServiceType.OPENAI
|
||||
OpenAISettings.getCurrentState().model = OpenAIChatCompletionModel.GPT_3_5.code
|
||||
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
|
||||
assertThat(conversation).isEqualTo(ConversationsState.getCurrentConversation())
|
||||
assertThat(conversation)
|
||||
.extracting("clientCode", "model")
|
||||
.containsExactly("chat.completion", "gpt-3.5-turbo")
|
||||
}
|
||||
|
||||
fun testSaveConversation() {
|
||||
val service = ConversationService.getInstance()
|
||||
val conversation = service.createConversation("chat.completion")
|
||||
service.addConversation(conversation)
|
||||
val message = Message("TEST_PROMPT")
|
||||
message.response = "TEST_RESPONSE"
|
||||
conversation.addMessage(message)
|
||||
|
||||
service.saveConversation(conversation)
|
||||
|
||||
val currentConversation = ConversationsState.getCurrentConversation()
|
||||
assertThat(currentConversation).isNotNull()
|
||||
assertThat(currentConversation!!.messages)
|
||||
.flatExtracting("prompt", "response")
|
||||
.containsExactly("TEST_PROMPT", "TEST_RESPONSE")
|
||||
}
|
||||
|
||||
fun testGetPreviousConversation() {
|
||||
val service = ConversationService.getInstance()
|
||||
val firstConversation = service.startConversation()
|
||||
service.startConversation()
|
||||
|
||||
val previousConversation = service.previousConversation
|
||||
|
||||
assertThat(previousConversation.isPresent).isTrue()
|
||||
assertThat(previousConversation.get()).isEqualTo(firstConversation)
|
||||
}
|
||||
|
||||
fun testGetNextConversation() {
|
||||
val service = ConversationService.getInstance()
|
||||
val firstConversation = service.startConversation()
|
||||
val secondConversation = service.startConversation()
|
||||
ConversationsState.getInstance().setCurrentConversation(firstConversation)
|
||||
|
||||
val nextConversation = service.nextConversation
|
||||
|
||||
assertThat(nextConversation.isPresent).isTrue()
|
||||
assertThat(nextConversation.get()).isEqualTo(secondConversation)
|
||||
}
|
||||
|
||||
fun testDeleteSelectedConversation() {
|
||||
val service = ConversationService.getInstance()
|
||||
val firstConversation = service.startConversation()
|
||||
service.startConversation()
|
||||
|
||||
service.deleteSelectedConversation()
|
||||
|
||||
assertThat(ConversationsState.getCurrentConversation()).isEqualTo(firstConversation)
|
||||
assertThat(service.sortedConversations.size).isEqualTo(1)
|
||||
assertThat(service.sortedConversations)
|
||||
.extracting("id")
|
||||
.containsExactly(firstConversation.id)
|
||||
}
|
||||
|
||||
fun testClearAllConversations() {
|
||||
val service = ConversationService.getInstance()
|
||||
service.startConversation()
|
||||
service.startConversation()
|
||||
|
||||
service.clearAll()
|
||||
|
||||
assertThat(ConversationsState.getCurrentConversation()).isNull()
|
||||
assertThat(service.sortedConversations.size).isEqualTo(0)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
package ee.carlrobert.codegpt.settings.state
|
||||
|
||||
import com.intellij.testFramework.fixtures.BasePlatformTestCase
|
||||
import ee.carlrobert.codegpt.completions.HuggingFaceModel
|
||||
import ee.carlrobert.codegpt.conversations.Conversation
|
||||
import ee.carlrobert.codegpt.settings.GeneralSettings
|
||||
import ee.carlrobert.codegpt.settings.service.ServiceType
|
||||
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings
|
||||
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
|
||||
class GeneralSettingsTest : BasePlatformTestCase() {
|
||||
|
||||
fun testOpenAISettingsSync() {
|
||||
val openAISettings = OpenAISettings.getCurrentState()
|
||||
openAISettings.model = "gpt-3.5-turbo"
|
||||
val conversation = Conversation()
|
||||
conversation.model = "gpt-4"
|
||||
conversation.clientCode = "chat.completion"
|
||||
val settings = GeneralSettings.getInstance()
|
||||
|
||||
settings.sync(conversation)
|
||||
|
||||
assertThat(settings.state.selectedService).isEqualTo(ServiceType.OPENAI)
|
||||
assertThat(openAISettings.model).isEqualTo("gpt-4")
|
||||
}
|
||||
|
||||
fun testAzureSettingsSync() {
|
||||
val settings = GeneralSettings.getInstance()
|
||||
val conversation = Conversation()
|
||||
conversation.model = "gpt-4"
|
||||
conversation.clientCode = "azure.chat.completion"
|
||||
|
||||
settings.sync(conversation)
|
||||
|
||||
assertThat(settings.state.selectedService).isEqualTo(ServiceType.AZURE)
|
||||
}
|
||||
|
||||
fun testYouSettingsSync() {
|
||||
val settings = GeneralSettings.getInstance()
|
||||
val conversation = Conversation()
|
||||
conversation.model = "YouCode"
|
||||
conversation.clientCode = "you.chat.completion"
|
||||
|
||||
settings.sync(conversation)
|
||||
|
||||
assertThat(settings.state.selectedService).isEqualTo(ServiceType.YOU)
|
||||
}
|
||||
|
||||
fun testLlamaSettingsModelPathSync() {
|
||||
val llamaSettings = LlamaSettings.getCurrentState()
|
||||
llamaSettings.huggingFaceModel = HuggingFaceModel.WIZARD_CODER_PYTHON_7B_Q3
|
||||
val conversation = Conversation()
|
||||
conversation.model = "TEST_LLAMA_MODEL_PATH"
|
||||
conversation.clientCode = "llama.chat.completion"
|
||||
val settings = GeneralSettings.getInstance()
|
||||
|
||||
settings.sync(conversation)
|
||||
|
||||
assertThat(settings.state.selectedService).isEqualTo(ServiceType.LLAMA_CPP)
|
||||
assertThat(llamaSettings.customLlamaModelPath).isEqualTo("TEST_LLAMA_MODEL_PATH")
|
||||
assertThat(llamaSettings.isUseCustomModel).isTrue()
|
||||
}
|
||||
|
||||
fun testLlamaSettingsHuggingFaceModelSync() {
|
||||
val llamaSettings = LlamaSettings.getCurrentState()
|
||||
llamaSettings.huggingFaceModel = HuggingFaceModel.WIZARD_CODER_PYTHON_7B_Q3
|
||||
val conversation = Conversation()
|
||||
conversation.model = "CODE_LLAMA_7B_Q3"
|
||||
conversation.clientCode = "llama.chat.completion"
|
||||
val settings = GeneralSettings.getInstance()
|
||||
|
||||
settings.sync(conversation)
|
||||
|
||||
assertThat(settings.state.selectedService).isEqualTo(ServiceType.LLAMA_CPP)
|
||||
assertThat(llamaSettings.huggingFaceModel).isEqualTo(HuggingFaceModel.CODE_LLAMA_7B_Q3)
|
||||
assertThat(llamaSettings.isUseCustomModel).isFalse()
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,410 @@
|
|||
package ee.carlrobert.codegpt.toolwindow.chat
|
||||
|
||||
import ee.carlrobert.codegpt.CodeGPTKeys
|
||||
import ee.carlrobert.codegpt.EncodingManager
|
||||
import ee.carlrobert.codegpt.ReferencedFile
|
||||
import ee.carlrobert.codegpt.completions.CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT
|
||||
import ee.carlrobert.codegpt.completions.CompletionRequestProvider.FIX_COMPILE_ERRORS_SYSTEM_PROMPT
|
||||
import ee.carlrobert.codegpt.completions.ConversationType
|
||||
import ee.carlrobert.codegpt.completions.HuggingFaceModel
|
||||
import ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA
|
||||
import ee.carlrobert.codegpt.conversations.ConversationService
|
||||
import ee.carlrobert.codegpt.conversations.message.Message
|
||||
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
|
||||
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings
|
||||
import ee.carlrobert.llm.client.http.RequestEntity
|
||||
import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange
|
||||
import ee.carlrobert.llm.client.util.JSONUtil.e
|
||||
import ee.carlrobert.llm.client.util.JSONUtil.jsonArray
|
||||
import ee.carlrobert.llm.client.util.JSONUtil.jsonMap
|
||||
import ee.carlrobert.llm.client.util.JSONUtil.jsonMapResponse
|
||||
import org.apache.http.HttpHeaders
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import testsupport.IntegrationTest
|
||||
import java.io.IOException
|
||||
import java.nio.file.Files
|
||||
import java.nio.file.Path
|
||||
import java.util.Base64
|
||||
import java.util.Objects
|
||||
|
||||
class ChatToolWindowTabPanelTest : IntegrationTest() {
|
||||
|
||||
fun testSendingOpenAIMessage() {
|
||||
useOpenAIService()
|
||||
ConfigurationSettings.getCurrentState().systemPrompt = COMPLETION_SYSTEM_PROMPT
|
||||
val message = Message("Hello!")
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
val panel = ChatToolWindowTabPanel(project, conversation)
|
||||
expectOpenAI(StreamHttpExchange { request: RequestEntity ->
|
||||
assertThat(request.uri.path).isEqualTo("/v1/chat/completions")
|
||||
assertThat(request.method).isEqualTo("POST")
|
||||
assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY")
|
||||
assertThat(request.body)
|
||||
.extracting(
|
||||
"model",
|
||||
"messages")
|
||||
.containsExactly(
|
||||
"gpt-4",
|
||||
listOf(
|
||||
mapOf("role" to "system", "content" to COMPLETION_SYSTEM_PROMPT),
|
||||
mapOf("role" to "user", "content" to "Hello!")))
|
||||
listOf(
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))))
|
||||
})
|
||||
|
||||
panel.sendMessage(message)
|
||||
|
||||
waitExpecting {
|
||||
val messages = conversation.messages
|
||||
messages.isNotEmpty() && "Hello!" == messages[0].response
|
||||
}
|
||||
val encodingManager = EncodingManager.getInstance()
|
||||
assertThat(panel.tokenDetails).extracting(
|
||||
"systemPromptTokens",
|
||||
"conversationTokens",
|
||||
"userPromptTokens",
|
||||
"highlightedTokens")
|
||||
.containsExactly(
|
||||
encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT),
|
||||
encodingManager.countTokens(message.prompt),
|
||||
0,
|
||||
0)
|
||||
assertThat(panel.conversation)
|
||||
.isNotNull()
|
||||
.extracting("id", "model", "clientCode", "discardTokenLimit")
|
||||
.containsExactly(
|
||||
conversation.id,
|
||||
conversation.model,
|
||||
conversation.clientCode,
|
||||
false)
|
||||
val messages = panel.conversation.messages
|
||||
assertThat(messages).hasSize(1)
|
||||
assertThat(messages[0])
|
||||
.extracting("id", "prompt", "response")
|
||||
.containsExactly(message.id, message.prompt, message.response)
|
||||
}
|
||||
|
||||
fun testSendingOpenAIMessageWithReferencedContext() {
|
||||
project.putUserData(CodeGPTKeys.SELECTED_FILES, listOf(
|
||||
ReferencedFile("TEST_FILE_NAME_1", "TEST_FILE_PATH_1", "TEST_FILE_CONTENT_1"),
|
||||
ReferencedFile("TEST_FILE_NAME_2", "TEST_FILE_PATH_2", "TEST_FILE_CONTENT_2"),
|
||||
ReferencedFile("TEST_FILE_NAME_3", "TEST_FILE_PATH_3", "TEST_FILE_CONTENT_3")))
|
||||
useOpenAIService()
|
||||
ConfigurationSettings.getCurrentState().systemPrompt = COMPLETION_SYSTEM_PROMPT
|
||||
val message = Message("TEST_MESSAGE")
|
||||
message.userMessage = "TEST_MESSAGE"
|
||||
message.referencedFilePaths = listOf("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3")
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
val panel = ChatToolWindowTabPanel(project, conversation)
|
||||
expectOpenAI(StreamHttpExchange { request: RequestEntity ->
|
||||
assertThat(request.uri.path).isEqualTo("/v1/chat/completions")
|
||||
assertThat(request.method).isEqualTo("POST")
|
||||
assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY")
|
||||
assertThat(request.body)
|
||||
.extracting(
|
||||
"model",
|
||||
"messages")
|
||||
.containsExactly(
|
||||
"gpt-4",
|
||||
listOf(
|
||||
mapOf("role" to "system", "content" to COMPLETION_SYSTEM_PROMPT),
|
||||
mapOf("role" to "user", "content" to """
|
||||
Use the following context to answer question at the end:
|
||||
|
||||
File Path: TEST_FILE_PATH_1
|
||||
File Content:
|
||||
```TEST_FILE_NAME_1
|
||||
TEST_FILE_CONTENT_1
|
||||
```
|
||||
|
||||
File Path: TEST_FILE_PATH_2
|
||||
File Content:
|
||||
```TEST_FILE_NAME_2
|
||||
TEST_FILE_CONTENT_2
|
||||
```
|
||||
|
||||
File Path: TEST_FILE_PATH_3
|
||||
File Content:
|
||||
```TEST_FILE_NAME_3
|
||||
TEST_FILE_CONTENT_3
|
||||
```
|
||||
|
||||
Question: TEST_MESSAGE""".trimIndent())))
|
||||
listOf(
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))))
|
||||
})
|
||||
|
||||
panel.sendMessage(message)
|
||||
|
||||
waitExpecting {
|
||||
val messages = conversation.messages
|
||||
messages.isNotEmpty() && "Hello!" == messages[0].response
|
||||
}
|
||||
val encodingManager = EncodingManager.getInstance()
|
||||
assertThat(panel.tokenDetails).extracting(
|
||||
"systemPromptTokens",
|
||||
"conversationTokens",
|
||||
"userPromptTokens",
|
||||
"highlightedTokens")
|
||||
.containsExactly(
|
||||
encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT),
|
||||
encodingManager.countTokens(message.prompt),
|
||||
0,
|
||||
0)
|
||||
assertThat(panel.conversation)
|
||||
.isNotNull()
|
||||
.extracting("id", "model", "clientCode", "discardTokenLimit")
|
||||
.containsExactly(
|
||||
conversation.id,
|
||||
conversation.model,
|
||||
conversation.clientCode,
|
||||
false)
|
||||
val messages = panel.conversation.messages
|
||||
assertThat(messages).hasSize(1)
|
||||
assertThat(messages[0])
|
||||
.extracting("id", "prompt", "response", "referencedFilePaths")
|
||||
.containsExactly(
|
||||
message.id,
|
||||
message.prompt,
|
||||
message.response,
|
||||
listOf("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3"))
|
||||
}
|
||||
|
||||
fun testSendingOpenAIMessageWithImage() {
|
||||
val testImagePath = Objects.requireNonNull(javaClass.getResource("/images/test-image.png")).path
|
||||
project.putUserData(CodeGPTKeys.IMAGE_ATTACHMENT_FILE_PATH, testImagePath)
|
||||
useOpenAIService("gpt-4-vision-preview")
|
||||
ConfigurationSettings.getCurrentState().systemPrompt = COMPLETION_SYSTEM_PROMPT
|
||||
val message = Message("TEST_MESSAGE")
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
val panel = ChatToolWindowTabPanel(project, conversation)
|
||||
expectOpenAI(StreamHttpExchange { request: RequestEntity ->
|
||||
assertThat(request.uri.path).isEqualTo("/v1/chat/completions")
|
||||
assertThat(request.method).isEqualTo("POST")
|
||||
assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY")
|
||||
try {
|
||||
val testImageUrl = ("data:image/png;base64,"
|
||||
+ Base64.getEncoder().encodeToString(Files.readAllBytes(Path.of(testImagePath))))
|
||||
assertThat(request.body)
|
||||
.extracting("model", "messages")
|
||||
.containsExactly(
|
||||
"gpt-4-vision-preview",
|
||||
listOf(
|
||||
mapOf("role" to "system", "content" to COMPLETION_SYSTEM_PROMPT),
|
||||
mapOf("role" to "user", "content" to listOf(
|
||||
mapOf(
|
||||
"type" to "image_url",
|
||||
"image_url" to mapOf("url" to testImageUrl)),
|
||||
mapOf("type" to "text", "text" to "TEST_MESSAGE")
|
||||
))))
|
||||
} catch (e: IOException) {
|
||||
throw RuntimeException(e)
|
||||
}
|
||||
listOf<String?>(
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))))
|
||||
})
|
||||
|
||||
panel.sendMessage(message)
|
||||
|
||||
waitExpecting {
|
||||
val messages = conversation.messages
|
||||
messages.isNotEmpty() && "Hello!" == messages[0].response
|
||||
}
|
||||
val encodingManager = EncodingManager.getInstance()
|
||||
assertThat(panel.tokenDetails).extracting(
|
||||
"systemPromptTokens",
|
||||
"conversationTokens",
|
||||
"userPromptTokens",
|
||||
"highlightedTokens")
|
||||
.containsExactly(
|
||||
encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT),
|
||||
encodingManager.countTokens(message.prompt),
|
||||
0,
|
||||
0)
|
||||
assertThat(panel.conversation)
|
||||
.isNotNull()
|
||||
.extracting("id", "model", "clientCode", "discardTokenLimit")
|
||||
.containsExactly(
|
||||
conversation.id,
|
||||
conversation.model,
|
||||
conversation.clientCode,
|
||||
false)
|
||||
val messages = panel.conversation.messages
|
||||
assertThat(messages).hasSize(1)
|
||||
assertThat(messages[0])
|
||||
.extracting("id", "prompt", "response", "imageFilePath")
|
||||
.containsExactly(
|
||||
message.id,
|
||||
message.prompt,
|
||||
message.response,
|
||||
message.imageFilePath)
|
||||
}
|
||||
|
||||
fun testFixCompileErrorsWithOpenAIService() {
|
||||
project.putUserData(
|
||||
CodeGPTKeys.SELECTED_FILES, listOf(
|
||||
ReferencedFile("TEST_FILE_NAME_1", "TEST_FILE_PATH_1", "TEST_FILE_CONTENT_1"),
|
||||
ReferencedFile("TEST_FILE_NAME_2", "TEST_FILE_PATH_2", "TEST_FILE_CONTENT_2"),
|
||||
ReferencedFile("TEST_FILE_NAME_3", "TEST_FILE_PATH_3", "TEST_FILE_CONTENT_3")))
|
||||
useOpenAIService()
|
||||
ConfigurationSettings.getCurrentState().systemPrompt = COMPLETION_SYSTEM_PROMPT
|
||||
val message = Message("TEST_MESSAGE")
|
||||
message.userMessage = "TEST_MESSAGE"
|
||||
message.referencedFilePaths = listOf("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3")
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
val panel = ChatToolWindowTabPanel(project, conversation)
|
||||
expectOpenAI(StreamHttpExchange { request: RequestEntity ->
|
||||
assertThat(request.uri.path).isEqualTo("/v1/chat/completions")
|
||||
assertThat(request.method).isEqualTo("POST")
|
||||
assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY")
|
||||
assertThat(request.body)
|
||||
.extracting(
|
||||
"model",
|
||||
"messages")
|
||||
.containsExactly(
|
||||
"gpt-4",
|
||||
listOf(
|
||||
mapOf("role" to "system", "content" to FIX_COMPILE_ERRORS_SYSTEM_PROMPT),
|
||||
mapOf("role" to "user", "content" to """
|
||||
Use the following context to answer question at the end:
|
||||
|
||||
File Path: TEST_FILE_PATH_1
|
||||
File Content:
|
||||
```TEST_FILE_NAME_1
|
||||
TEST_FILE_CONTENT_1
|
||||
```
|
||||
|
||||
File Path: TEST_FILE_PATH_2
|
||||
File Content:
|
||||
```TEST_FILE_NAME_2
|
||||
TEST_FILE_CONTENT_2
|
||||
```
|
||||
|
||||
File Path: TEST_FILE_PATH_3
|
||||
File Content:
|
||||
```TEST_FILE_NAME_3
|
||||
TEST_FILE_CONTENT_3
|
||||
```
|
||||
|
||||
Question: TEST_MESSAGE""".trimIndent())))
|
||||
listOf(
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
|
||||
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))))
|
||||
})
|
||||
|
||||
panel.sendMessage(message, ConversationType.FIX_COMPILE_ERRORS)
|
||||
|
||||
waitExpecting {
|
||||
val messages = conversation.messages
|
||||
messages.isNotEmpty() && "Hello!" == messages[0].response
|
||||
}
|
||||
val encodingManager = EncodingManager.getInstance()
|
||||
assertThat(panel.tokenDetails).extracting(
|
||||
"systemPromptTokens",
|
||||
"conversationTokens",
|
||||
"userPromptTokens",
|
||||
"highlightedTokens")
|
||||
.containsExactly(
|
||||
encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT),
|
||||
encodingManager.countTokens(message.prompt),
|
||||
0,
|
||||
0)
|
||||
assertThat(panel.conversation)
|
||||
.isNotNull()
|
||||
.extracting("id", "model", "clientCode", "discardTokenLimit")
|
||||
.containsExactly(
|
||||
conversation.id,
|
||||
conversation.model,
|
||||
conversation.clientCode,
|
||||
false)
|
||||
val messages = panel.conversation.messages
|
||||
assertThat(messages).hasSize(1)
|
||||
assertThat(messages[0])
|
||||
.extracting("id", "prompt", "response", "referencedFilePaths")
|
||||
.containsExactly(
|
||||
message.id,
|
||||
message.prompt,
|
||||
message.response,
|
||||
listOf("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3"))
|
||||
}
|
||||
|
||||
fun testSendingLlamaMessage() {
|
||||
useLlamaService()
|
||||
val configurationState = ConfigurationSettings.getCurrentState()
|
||||
configurationState.systemPrompt = COMPLETION_SYSTEM_PROMPT
|
||||
configurationState.maxTokens = 1000
|
||||
configurationState.temperature = 0.1
|
||||
val llamaSettings = LlamaSettings.getCurrentState()
|
||||
llamaSettings.isUseCustomModel = false
|
||||
llamaSettings.huggingFaceModel = HuggingFaceModel.CODE_LLAMA_7B_Q4
|
||||
llamaSettings.topK = 30
|
||||
llamaSettings.topP = 0.8
|
||||
llamaSettings.minP = 0.03
|
||||
llamaSettings.repeatPenalty = 1.3
|
||||
val message = Message("TEST_PROMPT")
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
val panel = ChatToolWindowTabPanel(project, conversation)
|
||||
expectLlama(StreamHttpExchange { request: RequestEntity ->
|
||||
assertThat(request.uri.path).isEqualTo("/completion")
|
||||
assertThat(request.body)
|
||||
.extracting(
|
||||
"prompt",
|
||||
"n_predict",
|
||||
"stream",
|
||||
"temperature",
|
||||
"top_k",
|
||||
"top_p",
|
||||
"min_p",
|
||||
"repeat_penalty")
|
||||
.containsExactly(
|
||||
LLAMA.buildPrompt(
|
||||
COMPLETION_SYSTEM_PROMPT,
|
||||
"TEST_PROMPT",
|
||||
conversation.messages),
|
||||
configurationState.maxTokens,
|
||||
true,
|
||||
configurationState.temperature,
|
||||
llamaSettings.topK,
|
||||
llamaSettings.topP,
|
||||
llamaSettings.minP,
|
||||
llamaSettings.repeatPenalty)
|
||||
listOf<String?>(
|
||||
jsonMapResponse("content", "Hel"),
|
||||
jsonMapResponse("content", "lo!"),
|
||||
jsonMapResponse(
|
||||
e("content", ""),
|
||||
e("stop", true)))
|
||||
})
|
||||
|
||||
panel.sendMessage(message, ConversationType.DEFAULT)
|
||||
|
||||
waitExpecting {
|
||||
val messages = conversation.messages
|
||||
messages.isNotEmpty() && "Hello!" == messages[0].response
|
||||
}
|
||||
assertThat(panel.conversation)
|
||||
.isNotNull()
|
||||
.extracting("id", "model", "clientCode", "discardTokenLimit")
|
||||
.containsExactly(
|
||||
conversation.id,
|
||||
conversation.model,
|
||||
conversation.clientCode,
|
||||
false)
|
||||
val messages = panel.conversation.messages
|
||||
assertThat(messages).hasSize(1)
|
||||
assertThat(messages[0])
|
||||
.extracting("id", "prompt", "response")
|
||||
.containsExactly(message.id, message.prompt, message.response)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
package ee.carlrobert.codegpt.toolwindow.chat
|
||||
|
||||
import com.intellij.openapi.util.Disposer
|
||||
import com.intellij.testFramework.fixtures.BasePlatformTestCase
|
||||
import ee.carlrobert.codegpt.conversations.ConversationService
|
||||
import ee.carlrobert.codegpt.conversations.message.Message
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
|
||||
class ChatToolWindowTabbedPaneTest : BasePlatformTestCase() {
|
||||
|
||||
fun testClearAllTabs() {
|
||||
val tabbedPane = ChatToolWindowTabbedPane(Disposer.newDisposable())
|
||||
tabbedPane.addNewTab(createNewTabPanel())
|
||||
|
||||
tabbedPane.clearAll()
|
||||
|
||||
assertThat(tabbedPane.activeTabMapping).isEmpty()
|
||||
}
|
||||
|
||||
|
||||
fun testAddingNewTabs() {
|
||||
val tabbedPane = ChatToolWindowTabbedPane(Disposer.newDisposable())
|
||||
|
||||
tabbedPane.addNewTab(createNewTabPanel())
|
||||
tabbedPane.addNewTab(createNewTabPanel())
|
||||
tabbedPane.addNewTab(createNewTabPanel())
|
||||
|
||||
assertThat(tabbedPane.activeTabMapping.keys)
|
||||
.containsExactly("Chat 1", "Chat 2", "Chat 3")
|
||||
}
|
||||
|
||||
fun testResetCurrentlyActiveTabPanel() {
|
||||
val tabbedPane = ChatToolWindowTabbedPane(Disposer.newDisposable())
|
||||
val conversation = ConversationService.getInstance().startConversation()
|
||||
conversation.addMessage(Message("TEST_PROMPT", "TEST_RESPONSE"))
|
||||
tabbedPane.addNewTab(ChatToolWindowTabPanel(project, conversation))
|
||||
|
||||
tabbedPane.resetCurrentlyActiveTabPanel(project)
|
||||
|
||||
val tabPanel = tabbedPane.activeTabMapping["Chat 1"]
|
||||
assertThat(tabPanel!!.conversation.messages).isEmpty()
|
||||
}
|
||||
|
||||
private fun createNewTabPanel(): ChatToolWindowTabPanel {
|
||||
return ChatToolWindowTabPanel(
|
||||
project,
|
||||
ConversationService.getInstance().startConversation()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,14 +1,13 @@
|
|||
package ee.carlrobert.codegpt.util;
|
||||
package ee.carlrobert.codegpt.util
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.junit.Test
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
public class MarkdownUtilTest {
|
||||
class MarkdownUtilTest {
|
||||
|
||||
@Test
|
||||
public void shouldExtractMarkdownCodeBlocks() {
|
||||
String testInput = """
|
||||
fun shouldExtractMarkdownCodeBlocks() {
|
||||
val testInput = """
|
||||
**C++ Code Block**
|
||||
```cpp
|
||||
#include <iostream>
|
||||
|
|
@ -29,60 +28,68 @@ public class MarkdownUtilTest {
|
|||
```
|
||||
1. We define a **public class** called **Main**.
|
||||
2. We define the **main** method which is the entry point of the program.
|
||||
""";
|
||||
|
||||
var result = MarkdownUtil.splitCodeBlocks(testInput);
|
||||
""".trimIndent()
|
||||
|
||||
val result = MarkdownUtil.splitCodeBlocks(testInput)
|
||||
|
||||
assertThat(result).containsExactly("""
|
||||
**C++ Code Block**
|
||||
""", """
|
||||
|
||||
""".trimIndent(), """
|
||||
```cpp
|
||||
#include <iostream>
|
||||
|
||||
int main() {
|
||||
return 0;
|
||||
}
|
||||
```""", """
|
||||
```""".trimIndent(), """
|
||||
|
||||
1. We include the **iostream** header file.
|
||||
2. We define the main function.
|
||||
|
||||
**Java Code Block**
|
||||
""", """
|
||||
|
||||
""".trimIndent(), """
|
||||
```java
|
||||
public class Main {
|
||||
public static void main(String[] args) {
|
||||
}
|
||||
}
|
||||
```""", """
|
||||
```""".trimIndent(), """
|
||||
|
||||
1. We define a **public class** called **Main**.
|
||||
2. We define the **main** method which is the entry point of the program.
|
||||
""");
|
||||
|
||||
""".trimIndent())
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldExtractMarkdownWithoutCode() {
|
||||
String testInput = """
|
||||
fun shouldExtractMarkdownWithoutCode() {
|
||||
val testInput = """
|
||||
**C++ Code Block**
|
||||
1. We include the **iostream** header file.
|
||||
2. We define the main function.
|
||||
|
||||
""";
|
||||
|
||||
var result = MarkdownUtil.splitCodeBlocks(testInput);
|
||||
|
||||
""".trimIndent()
|
||||
|
||||
val result = MarkdownUtil.splitCodeBlocks(testInput)
|
||||
|
||||
assertThat(result).containsExactly("""
|
||||
**C++ Code Block**
|
||||
1. We include the **iostream** header file.
|
||||
2. We define the main function.
|
||||
|
||||
""");
|
||||
|
||||
|
||||
""".trimIndent())
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldExtractMarkdownCodeOnly() {
|
||||
String testInput = """
|
||||
fun shouldExtractMarkdownCodeOnly() {
|
||||
val testInput = """
|
||||
```cpp
|
||||
#include <iostream>
|
||||
|
||||
|
|
@ -96,9 +103,10 @@ public class MarkdownUtilTest {
|
|||
}
|
||||
}
|
||||
```
|
||||
""";
|
||||
|
||||
var result = MarkdownUtil.splitCodeBlocks(testInput);
|
||||
""".trimIndent()
|
||||
|
||||
val result = MarkdownUtil.splitCodeBlocks(testInput)
|
||||
|
||||
assertThat(result).containsExactly("""
|
||||
```cpp
|
||||
|
|
@ -107,12 +115,12 @@ public class MarkdownUtilTest {
|
|||
int main() {
|
||||
return 0;
|
||||
}
|
||||
```""", """
|
||||
```""".trimIndent(), """
|
||||
```java
|
||||
public class Main {
|
||||
public static void main(String[] args) {
|
||||
}
|
||||
}
|
||||
```""");
|
||||
```""".trimIndent())
|
||||
}
|
||||
}
|
||||
33
src/test/kotlin/testsupport/IntegrationTest.kt
Normal file
33
src/test/kotlin/testsupport/IntegrationTest.kt
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
package testsupport
|
||||
|
||||
import com.intellij.openapi.util.Key
|
||||
import com.intellij.testFramework.fixtures.BasePlatformTestCase
|
||||
import ee.carlrobert.codegpt.CodeGPTKeys
|
||||
import ee.carlrobert.llm.client.mixin.ExternalServiceTestMixin
|
||||
import testsupport.mixin.ShortcutsTestMixin
|
||||
|
||||
open class IntegrationTest : BasePlatformTestCase(), ExternalServiceTestMixin, ShortcutsTestMixin {
|
||||
|
||||
@Throws(Exception::class)
|
||||
override fun tearDown() {
|
||||
ExternalServiceTestMixin.clearAll()
|
||||
clearKeys()
|
||||
super.tearDown()
|
||||
}
|
||||
|
||||
private fun clearKeys() {
|
||||
putUserData(CodeGPTKeys.SELECTED_FILES, emptyList())
|
||||
putUserData(CodeGPTKeys.PREVIOUS_INLAY_TEXT, "")
|
||||
putUserData(CodeGPTKeys.IMAGE_ATTACHMENT_FILE_PATH, "")
|
||||
}
|
||||
|
||||
private fun <T> putUserData(key: Key<T>, value: T) {
|
||||
project.putUserData(key, value)
|
||||
}
|
||||
|
||||
companion object {
|
||||
init {
|
||||
ExternalServiceTestMixin.init()
|
||||
}
|
||||
}
|
||||
}
|
||||
52
src/test/kotlin/testsupport/mixin/ShortcutsTestMixin.kt
Normal file
52
src/test/kotlin/testsupport/mixin/ShortcutsTestMixin.kt
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
package testsupport.mixin
|
||||
|
||||
import com.intellij.testFramework.PlatformTestUtil
|
||||
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.AZURE_OPENAI_API_KEY
|
||||
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.OPENAI_API_KEY
|
||||
import ee.carlrobert.codegpt.credentials.CredentialsStore.setCredential
|
||||
import ee.carlrobert.codegpt.settings.GeneralSettings
|
||||
import ee.carlrobert.codegpt.settings.service.ServiceType
|
||||
import ee.carlrobert.codegpt.settings.service.azure.AzureSettings
|
||||
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings
|
||||
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings
|
||||
import java.util.function.BooleanSupplier
|
||||
|
||||
interface ShortcutsTestMixin {
|
||||
|
||||
fun useOpenAIService() {
|
||||
useOpenAIService("gpt-4")
|
||||
}
|
||||
|
||||
fun useOpenAIService(model: String? = "gpt-4") {
|
||||
GeneralSettings.getCurrentState().selectedService = ServiceType.OPENAI
|
||||
setCredential(OPENAI_API_KEY, "TEST_API_KEY")
|
||||
OpenAISettings.getCurrentState().model = model
|
||||
}
|
||||
|
||||
fun useAzureService() {
|
||||
GeneralSettings.getCurrentState().selectedService = ServiceType.AZURE
|
||||
setCredential(AZURE_OPENAI_API_KEY, "TEST_API_KEY")
|
||||
val azureSettings = AzureSettings.getCurrentState()
|
||||
azureSettings.resourceName = "TEST_RESOURCE_NAME"
|
||||
azureSettings.apiVersion = "TEST_API_VERSION"
|
||||
azureSettings.deploymentId = "TEST_DEPLOYMENT_ID"
|
||||
}
|
||||
|
||||
fun useYouService() {
|
||||
GeneralSettings.getCurrentState().selectedService = ServiceType.YOU
|
||||
}
|
||||
|
||||
fun useLlamaService() {
|
||||
GeneralSettings.getCurrentState().selectedService = ServiceType.LLAMA_CPP
|
||||
LlamaSettings.getCurrentState().serverPort = null
|
||||
}
|
||||
|
||||
fun waitExpecting(condition: BooleanSupplier?) {
|
||||
PlatformTestUtil.waitWithEventsDispatching(
|
||||
"Waiting for message response timed out or did not meet expected conditions",
|
||||
condition!!,
|
||||
5
|
||||
)
|
||||
}
|
||||
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue