diff --git a/src/test/java/ee/carlrobert/codegpt/codecompletions/CodeCompletionServiceTest.java b/src/test/java/ee/carlrobert/codegpt/codecompletions/CodeCompletionServiceTest.java deleted file mode 100644 index bdd390a4..00000000 --- a/src/test/java/ee/carlrobert/codegpt/codecompletions/CodeCompletionServiceTest.java +++ /dev/null @@ -1,43 +0,0 @@ -package ee.carlrobert.codegpt.codecompletions; - -import static ee.carlrobert.codegpt.CodeGPTKeys.PREVIOUS_INLAY_TEXT; -import static ee.carlrobert.codegpt.codecompletions.InfillPromptTemplate.LLAMA; -import static ee.carlrobert.codegpt.util.file.FileUtil.getResourceContent; -import static ee.carlrobert.llm.client.util.JSONUtil.e; -import static ee.carlrobert.llm.client.util.JSONUtil.jsonMapResponse; -import static org.assertj.core.api.Assertions.assertThat; - -import com.intellij.openapi.editor.VisualPosition; -import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings; -import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange; -import java.util.List; -import testsupport.IntegrationTest; - -public class CodeCompletionServiceTest extends IntegrationTest { - - private final VisualPosition cursorPosition = new VisualPosition(3, 0); - - public void testFetchCodeCompletionLlama() { - useLlamaService(); - LlamaSettings.getCurrentState().setCodeCompletionsEnabled(true); - myFixture.configureByText( - "CompletionTest.java", - getResourceContent("/codecompletions/code-completion-file.txt")); - myFixture.getEditor().getCaretModel().moveToVisualPosition(cursorPosition); - var expectedCompletion = "TEST_OUTPUT"; - var prefix = "z".repeat(245) + "\n[INPUT]\nc"; // 128 tokens - var suffix = "\n[\\INPUT]\n" + "z".repeat(247); // 128 tokens - expectLlama((StreamHttpExchange) request -> { - assertThat(request.getUri().getPath()).isEqualTo("/completion"); - assertThat(request.getMethod()).isEqualTo("POST"); - assertThat(request.getBody()) - .extracting("prompt") - .isEqualTo(LLAMA.buildPrompt(prefix, suffix)); - return List.of(jsonMapResponse(e("content", expectedCompletion), e("stop", true))); - }); - - myFixture.type('c'); - - waitExpecting(() -> "TEST_OUTPUT".equals(PREVIOUS_INLAY_TEXT.get(myFixture.getEditor()))); - } -} diff --git a/src/test/java/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.java b/src/test/java/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.java deleted file mode 100644 index 21cb1054..00000000 --- a/src/test/java/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.java +++ /dev/null @@ -1,154 +0,0 @@ -package ee.carlrobert.codegpt.completions; - -import static ee.carlrobert.codegpt.completions.CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.groups.Tuple.tuple; - -import ee.carlrobert.codegpt.conversations.ConversationService; -import ee.carlrobert.codegpt.conversations.message.Message; -import ee.carlrobert.codegpt.credentials.CredentialsStore; -import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey; -import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings; -import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel; -import testsupport.IntegrationTest; - -public class CompletionRequestProviderTest extends IntegrationTest { - - public void testChatCompletionRequestWithSystemPromptOverride() { - CredentialsStore.INSTANCE.setCredential(CredentialKey.OPENAI_API_KEY, "TEST_API_KEY"); - ConfigurationSettings.getCurrentState().setSystemPrompt("TEST_SYSTEM_PROMPT"); - var conversation = ConversationService.getInstance().startConversation(); - var firstMessage = createDummyMessage(500); - var secondMessage = createDummyMessage(250); - conversation.addMessage(firstMessage); - conversation.addMessage(secondMessage); - - var request = new CompletionRequestProvider(conversation) - .buildOpenAIChatCompletionRequest( - OpenAIChatCompletionModel.GPT_3_5.getCode(), - new CallParameters( - conversation, - ConversationType.DEFAULT, - new Message("TEST_CHAT_COMPLETION_PROMPT"), - false)); - - assertThat(request.getMessages()) - .extracting("role", "content") - .containsExactly( - tuple("system", "TEST_SYSTEM_PROMPT"), - tuple("user", "TEST_PROMPT"), - tuple("assistant", firstMessage.getResponse()), - tuple("user", "TEST_PROMPT"), - tuple("assistant", secondMessage.getResponse()), - tuple("user", "TEST_CHAT_COMPLETION_PROMPT")); - } - - public void testChatCompletionRequestWithoutSystemPromptOverride() { - var conversation = ConversationService.getInstance().startConversation(); - var firstMessage = createDummyMessage(500); - var secondMessage = createDummyMessage(250); - conversation.addMessage(firstMessage); - conversation.addMessage(secondMessage); - - var request = new CompletionRequestProvider(conversation) - .buildOpenAIChatCompletionRequest( - OpenAIChatCompletionModel.GPT_3_5.getCode(), - new CallParameters( - conversation, - ConversationType.DEFAULT, - new Message("TEST_CHAT_COMPLETION_PROMPT"), - false)); - - assertThat(request.getMessages()) - .extracting("role", "content") - .containsExactly( - tuple("system", COMPLETION_SYSTEM_PROMPT), - tuple("user", "TEST_PROMPT"), - tuple("assistant", firstMessage.getResponse()), - tuple("user", "TEST_PROMPT"), - tuple("assistant", secondMessage.getResponse()), - tuple("user", "TEST_CHAT_COMPLETION_PROMPT")); - } - - public void testChatCompletionRequestRetry() { - ConfigurationSettings.getCurrentState().setSystemPrompt(COMPLETION_SYSTEM_PROMPT); - var conversation = ConversationService.getInstance().startConversation(); - var firstMessage = createDummyMessage("FIRST_TEST_PROMPT", 500); - var secondMessage = createDummyMessage("SECOND_TEST_PROMPT", 250); - conversation.addMessage(firstMessage); - conversation.addMessage(secondMessage); - - var request = new CompletionRequestProvider(conversation) - .buildOpenAIChatCompletionRequest( - OpenAIChatCompletionModel.GPT_3_5.getCode(), - new CallParameters( - conversation, - ConversationType.DEFAULT, - secondMessage, - true)); - - assertThat(request.getMessages()) - .extracting("role", "content") - .containsExactly( - tuple("system", COMPLETION_SYSTEM_PROMPT), - tuple("user", "FIRST_TEST_PROMPT"), - tuple("assistant", firstMessage.getResponse()), - tuple("user", "SECOND_TEST_PROMPT")); - } - - public void testReducedChatCompletionRequest() { - var conversation = ConversationService.getInstance().startConversation(); - conversation.addMessage(createDummyMessage(50)); - conversation.addMessage(createDummyMessage(100)); - conversation.addMessage(createDummyMessage(150)); - conversation.addMessage(createDummyMessage(1000)); - var remainingMessage = createDummyMessage(2000); - conversation.addMessage(remainingMessage); - conversation.discardTokenLimits(); - - var request = new CompletionRequestProvider(conversation) - .buildOpenAIChatCompletionRequest( - OpenAIChatCompletionModel.GPT_3_5.getCode(), - new CallParameters( - conversation, - ConversationType.DEFAULT, - new Message("TEST_CHAT_COMPLETION_PROMPT"), - false)); - - assertThat(request.getMessages()) - .extracting("role", "content") - .containsExactly( - tuple("system", COMPLETION_SYSTEM_PROMPT), - tuple("user", "TEST_PROMPT"), - tuple("assistant", remainingMessage.getResponse()), - tuple("user", "TEST_CHAT_COMPLETION_PROMPT")); - } - - public void testTotalUsageExceededException() { - var conversation = ConversationService.getInstance().startConversation(); - conversation.addMessage(createDummyMessage(1500)); - conversation.addMessage(createDummyMessage(1500)); - conversation.addMessage(createDummyMessage(1500)); - - assertThrows(TotalUsageExceededException.class, - () -> new CompletionRequestProvider(conversation) - .buildOpenAIChatCompletionRequest( - OpenAIChatCompletionModel.GPT_3_5.getCode(), - new CallParameters( - conversation, - ConversationType.DEFAULT, - createDummyMessage(100), - false))); - } - - private Message createDummyMessage(int tokenSize) { - return createDummyMessage("TEST_PROMPT", tokenSize); - } - - private Message createDummyMessage(String prompt, int tokenSize) { - var message = new Message(prompt); - // 'zz' = 1 token, prompt = 6 tokens, 7 tokens per message (GPT-3), - message.setResponse("zz".repeat((tokenSize) - 6 - 7)); - return message; - } -} diff --git a/src/test/java/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.java b/src/test/java/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.java deleted file mode 100644 index 18fa3020..00000000 --- a/src/test/java/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.java +++ /dev/null @@ -1,183 +0,0 @@ -package ee.carlrobert.codegpt.completions; - -import static ee.carlrobert.codegpt.completions.CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT; -import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA; -import static ee.carlrobert.llm.client.util.JSONUtil.e; -import static ee.carlrobert.llm.client.util.JSONUtil.jsonArray; -import static ee.carlrobert.llm.client.util.JSONUtil.jsonMap; -import static ee.carlrobert.llm.client.util.JSONUtil.jsonMapResponse; -import static org.apache.http.HttpHeaders.AUTHORIZATION; -import static org.assertj.core.api.Assertions.assertThat; - -import ee.carlrobert.codegpt.CodeGPTPlugin; -import ee.carlrobert.codegpt.conversations.ConversationService; -import ee.carlrobert.codegpt.conversations.message.Message; -import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings; -import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange; -import java.util.List; -import java.util.Map; -import testsupport.IntegrationTest; - -public class DefaultCompletionRequestHandlerTest extends IntegrationTest { - - public void testOpenAIChatCompletionCall() { - useOpenAIService(); - var message = new Message("TEST_PROMPT"); - var conversation = ConversationService.getInstance().startConversation(); - var requestHandler = new CompletionRequestHandler(getRequestEventListener(message)); - expectOpenAI((StreamHttpExchange) request -> { - assertThat(request.getUri().getPath()).isEqualTo("/v1/chat/completions"); - assertThat(request.getMethod()).isEqualTo("POST"); - assertThat(request.getHeaders().get(AUTHORIZATION).get(0)).isEqualTo("Bearer TEST_API_KEY"); - assertThat(request.getBody()) - .extracting( - "model", - "messages") - .containsExactly( - "gpt-4", - List.of( - Map.of("role", "system", "content", COMPLETION_SYSTEM_PROMPT), - Map.of("role", "user", "content", "TEST_PROMPT"))); - - return List.of( - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))); - }); - - requestHandler.call(new CallParameters(conversation, ConversationType.DEFAULT, message, false)); - - waitExpecting(() -> "Hello!".equals(message.getResponse())); - } - - public void testAzureChatCompletionCall() { - useAzureService(); - var conversationService = ConversationService.getInstance(); - var prevMessage = new Message("TEST_PREV_PROMPT"); - prevMessage.setResponse("TEST_PREV_RESPONSE"); - var conversation = conversationService.startConversation(); - conversation.addMessage(prevMessage); - conversationService.saveConversation(conversation); - expectAzure((StreamHttpExchange) request -> { - assertThat(request.getUri().getPath()).isEqualTo( - "/openai/deployments/TEST_DEPLOYMENT_ID/chat/completions"); - assertThat(request.getUri().getQuery()).isEqualTo("api-version=TEST_API_VERSION"); - assertThat(request.getHeaders().get("Api-key").get(0)).isEqualTo("TEST_API_KEY"); - assertThat(request.getHeaders().get("X-llm-application-tag").get(0)).isEqualTo("codegpt"); - assertThat(request.getBody()) - .extracting("messages") - .isEqualTo( - List.of( - Map.of("role", "system", "content", COMPLETION_SYSTEM_PROMPT), - Map.of("role", "user", "content", "TEST_PREV_PROMPT"), - Map.of("role", "assistant", "content", "TEST_PREV_RESPONSE"), - Map.of("role", "user", "content", "TEST_PROMPT"))); - return List.of( - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))); - }); - var message = new Message("TEST_PROMPT"); - var requestHandler = new CompletionRequestHandler(getRequestEventListener(message)); - - requestHandler.call(new CallParameters(conversation, ConversationType.DEFAULT, message, false)); - - waitExpecting(() -> "Hello!".equals(message.getResponse())); - } - - public void testYouChatCompletionCall() { - useYouService(); - var message = new Message("TEST_PROMPT"); - var conversation = ConversationService.getInstance().startConversation(); - conversation.addMessage(new Message("Ping", "Pong")); - var requestHandler = new CompletionRequestHandler(getRequestEventListener(message)); - expectYou((StreamHttpExchange) request -> { - assertThat(request.getUri().getPath()).isEqualTo("/api/streamingSearch"); - assertThat(request.getMethod()).isEqualTo("GET"); - assertThat(request.getUri().getPath()).isEqualTo("/api/streamingSearch"); - assertThat(request.getUri().getQuery()).isEqualTo( - "q=TEST_PROMPT&" - + "page=1&" - + "cfr=CodeGPT&" - + "count=10&" - + "safeSearch=WebPages,Translations,TimeZone,Computation,RelatedSearches&" - + "domain=youchat&" - + "selectedChatMode=default&" - + "chat=[{\"question\":\"Ping\",\"answer\":\"Pong\"}]&" - + "utm_source=ide&" - + "utm_medium=jetbrains&" - + "utm_campaign=" + CodeGPTPlugin.getVersion() + "&" - + "utm_content=CodeGPT"); - assertThat(request.getHeaders()) - .flatExtracting("Accept", "Connection", "User-agent", "Cookie") - .containsExactly( - "text/event-stream", - "Keep-Alive", - "youide CodeGPT", - "safesearch_guest=Moderate; " - + "youpro_subscription=true; " - + "you_subscription=free; " - + "stytch_session=; " - + "ydc_stytch_session=; " - + "stytch_session_jwt=; " - + "ydc_stytch_session_jwt=; " - + "eg4=false; " - + "__cf_bm=aN2b3pQMH8XADeMB7bg9s1bJ_bfXBcCHophfOGRg6g0-1693601599-0-" - + "AWIt5Mr4Y3xQI4mIJ1lSf4+vijWKDobrty8OopDeBxY+NABe0MRFidF3dCUoWjRt8" - + "SVMvBZPI3zkOgcRs7Mz3yazd7f7c58HwW5Xg9jdBjNg;"); - return List.of( - jsonMapResponse("youChatToken", "Hel"), - jsonMapResponse("youChatToken", "lo"), - jsonMapResponse("youChatToken", "!")); - }); - - requestHandler.call(new CallParameters(conversation, ConversationType.DEFAULT, message, false)); - - waitExpecting(() -> "Hello!".equals(message.getResponse())); - } - - public void testLlamaChatCompletionCall() { - useLlamaService(); - ConfigurationSettings.getCurrentState().setMaxTokens(99); - var message = new Message("TEST_PROMPT"); - var conversation = ConversationService.getInstance().startConversation(); - conversation.addMessage(new Message("Ping", "Pong")); - var requestHandler = new CompletionRequestHandler(getRequestEventListener(message)); - expectLlama((StreamHttpExchange) request -> { - assertThat(request.getUri().getPath()).isEqualTo("/completion"); - assertThat(request.getBody()) - .extracting( - "prompt", - "n_predict", - "stream") - .containsExactly( - LLAMA.buildPrompt( - COMPLETION_SYSTEM_PROMPT, - "TEST_PROMPT", - conversation.getMessages()), - 99, - true); - return List.of( - jsonMapResponse("content", "Hel"), - jsonMapResponse("content", "lo!"), - jsonMapResponse( - e("content", ""), - e("stop", true))); - }); - - requestHandler.call(new CallParameters(conversation, ConversationType.DEFAULT, message, false)); - - waitExpecting(() -> "Hello!".equals(message.getResponse())); - } - - private CompletionResponseEventListener getRequestEventListener(Message message) { - return new CompletionResponseEventListener() { - @Override - public void handleCompleted(String fullMessage, CallParameters callParameters) { - message.setResponse(fullMessage); - } - }; - } -} diff --git a/src/test/java/ee/carlrobert/codegpt/completions/PromptTemplateTest.java b/src/test/java/ee/carlrobert/codegpt/completions/PromptTemplateTest.java deleted file mode 100644 index 4fa25a79..00000000 --- a/src/test/java/ee/carlrobert/codegpt/completions/PromptTemplateTest.java +++ /dev/null @@ -1,145 +0,0 @@ -package ee.carlrobert.codegpt.completions; - -import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.ALPACA; -import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.CHAT_ML; -import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA; -import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.TORA; -import static org.assertj.core.api.Assertions.assertThat; - -import ee.carlrobert.codegpt.conversations.message.Message; -import java.util.List; -import org.junit.Test; - -public class PromptTemplateTest { - - private static final String SYSTEM_PROMPT = "TEST_SYSTEM_PROMPT"; - private static final String USER_PROMPT = "TEST_USER_PROMPT"; - private static final List HISTORY = List.of( - new Message("TEST_PREV_PROMPT_1", "TEST_PREV_RESPONSE_1"), - new Message("TEST_PREV_PROMPT_2", "TEST_PREV_RESPONSE_2")); - - @Test - public void shouldBuildLlamaPromptWithHistory() { - var prompt = LLAMA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY); - - assertThat(prompt).isEqualTo(""" - <>TEST_SYSTEM_PROMPT<> - [INST]TEST_PREV_PROMPT_1[/INST] - TEST_PREV_RESPONSE_1 - [INST]TEST_PREV_PROMPT_2[/INST] - TEST_PREV_RESPONSE_2 - [INST]TEST_USER_PROMPT[/INST]"""); - } - - @Test - public void shouldBuildLlamaPromptWithoutHistory() { - var prompt = LLAMA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, List.of()); - - assertThat(prompt).isEqualTo(""" - <>TEST_SYSTEM_PROMPT<> - [INST]TEST_USER_PROMPT[/INST]"""); - } - - @Test - public void shouldBuildAlpacaPromptWithHistory() { - var prompt = ALPACA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY); - - assertThat(prompt).isEqualTo(""" - Below is an instruction that describes a task. \ - Write a response that appropriately completes the request. - - ### Instruction - TEST_PREV_PROMPT_1 - - ### Response: - TEST_PREV_RESPONSE_1 - - ### Instruction - TEST_PREV_PROMPT_2 - - ### Response: - TEST_PREV_RESPONSE_2 - - ### Instruction - TEST_USER_PROMPT - - ### Response: - """); - } - - @Test - public void shouldBuildAlpacaPromptWithoutHistory() { - var prompt = ALPACA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, List.of()); - - assertThat(prompt).isEqualTo(""" - Below is an instruction that describes a task. \ - Write a response that appropriately completes the request. - - ### Instruction - TEST_USER_PROMPT - - ### Response: - """); - } - - @Test - public void shouldBuildChatMLPromptWithHistory() { - var prompt = CHAT_ML.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY); - - assertThat(prompt).isEqualTo(""" - <|im_start|>system - TEST_SYSTEM_PROMPT<|im_end|> - <|im_start|>user - TEST_PREV_PROMPT_1<|im_end|> - <|im_start|>assistant - TEST_PREV_RESPONSE_1<|im_end|> - <|im_start|>user - TEST_PREV_PROMPT_2<|im_end|> - <|im_start|>assistant - TEST_PREV_RESPONSE_2<|im_end|> - <|im_start|>user - TEST_USER_PROMPT<|im_end|>""" - ); - } - - @Test - public void shouldBuildChatMLPromptWithoutHistory() { - var prompt = CHAT_ML.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, List.of()); - - assertThat(prompt).isEqualTo(""" - <|im_start|>system - TEST_SYSTEM_PROMPT<|im_end|> - <|im_start|>user - TEST_USER_PROMPT<|im_end|>"""); - } - - @Test - public void shouldBuildToRAPromptWithHistory() { - var prompt = TORA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY); - - assertThat(prompt).isEqualTo(""" - <|user|> - TEST_PREV_PROMPT_1 - <|assistant|> - TEST_PREV_RESPONSE_1 - <|user|> - TEST_PREV_PROMPT_2 - <|assistant|> - TEST_PREV_RESPONSE_2 - <|user|> - TEST_USER_PROMPT - <|assistant|>""" - ); - } - - @Test - public void shouldBuildToRAPromptWithoutHistory() { - var prompt = TORA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, List.of()); - - assertThat(prompt).isEqualTo(""" - <|user|> - TEST_USER_PROMPT - <|assistant|>""" - ); - } -} diff --git a/src/test/java/ee/carlrobert/codegpt/conversations/ConversationsStateTest.java b/src/test/java/ee/carlrobert/codegpt/conversations/ConversationsStateTest.java deleted file mode 100644 index 7eb9c3e9..00000000 --- a/src/test/java/ee/carlrobert/codegpt/conversations/ConversationsStateTest.java +++ /dev/null @@ -1,90 +0,0 @@ -package ee.carlrobert.codegpt.conversations; - -import static org.assertj.core.api.Assertions.assertThat; - -import com.intellij.testFramework.fixtures.BasePlatformTestCase; -import ee.carlrobert.codegpt.conversations.message.Message; -import ee.carlrobert.codegpt.settings.GeneralSettings; -import ee.carlrobert.codegpt.settings.service.ServiceType; -import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings; -import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel; - -public class ConversationsStateTest extends BasePlatformTestCase { - - public void testStartNewDefaultConversation() { - GeneralSettings.getCurrentState().setSelectedService(ServiceType.OPENAI); - OpenAISettings.getCurrentState().setModel(OpenAIChatCompletionModel.GPT_3_5.getCode()); - - var conversation = ConversationService.getInstance().startConversation(); - - assertThat(conversation).isEqualTo(ConversationsState.getCurrentConversation()); - assertThat(conversation) - .extracting("clientCode", "model") - .containsExactly("chat.completion", "gpt-3.5-turbo"); - } - - public void testSaveConversation() { - var service = ConversationService.getInstance(); - var conversation = service.createConversation("chat.completion"); - service.addConversation(conversation); - var message = new Message("TEST_PROMPT"); - message.setResponse("TEST_RESPONSE"); - conversation.addMessage(message); - - service.saveConversation(conversation); - - var currentConversation = ConversationsState.getCurrentConversation(); - assertThat(currentConversation).isNotNull(); - assertThat(currentConversation.getMessages()) - .flatExtracting("prompt", "response") - .containsExactly("TEST_PROMPT", "TEST_RESPONSE"); - } - - public void testGetPreviousConversation() { - var service = ConversationService.getInstance(); - var firstConversation = service.startConversation(); - service.startConversation(); - - var previousConversation = service.getPreviousConversation(); - - assertThat(previousConversation.isPresent()).isTrue(); - assertThat(previousConversation.get()).isEqualTo(firstConversation); - } - - public void testGetNextConversation() { - var service = ConversationService.getInstance(); - var firstConversation = service.startConversation(); - var secondConversation = service.startConversation(); - ConversationsState.getInstance().setCurrentConversation(firstConversation); - - var nextConversation = service.getNextConversation(); - - assertThat(nextConversation.isPresent()).isTrue(); - assertThat(nextConversation.get()).isEqualTo(secondConversation); - } - - public void testDeleteSelectedConversation() { - var service = ConversationService.getInstance(); - var firstConversation = service.startConversation(); - service.startConversation(); - - service.deleteSelectedConversation(); - - assertThat(ConversationsState.getCurrentConversation()).isEqualTo(firstConversation); - assertThat(service.getSortedConversations().size()).isEqualTo(1); - assertThat(service.getSortedConversations()) - .extracting("id") - .containsExactly(firstConversation.getId()); - } - - public void testClearAllConversations() { - var service = ConversationService.getInstance(); - service.startConversation(); - service.startConversation(); - - service.clearAll(); - - assertThat(ConversationsState.getCurrentConversation()).isNull(); - assertThat(service.getSortedConversations().size()).isEqualTo(0); - } -} diff --git a/src/test/java/ee/carlrobert/codegpt/settings/state/GeneralSettingsTest.java b/src/test/java/ee/carlrobert/codegpt/settings/state/GeneralSettingsTest.java deleted file mode 100644 index 99c12cac..00000000 --- a/src/test/java/ee/carlrobert/codegpt/settings/state/GeneralSettingsTest.java +++ /dev/null @@ -1,81 +0,0 @@ -package ee.carlrobert.codegpt.settings.state; - -import static ee.carlrobert.codegpt.completions.HuggingFaceModel.CODE_LLAMA_7B_Q3; -import static org.assertj.core.api.Assertions.assertThat; - -import com.intellij.testFramework.fixtures.BasePlatformTestCase; -import ee.carlrobert.codegpt.completions.HuggingFaceModel; -import ee.carlrobert.codegpt.conversations.Conversation; -import ee.carlrobert.codegpt.settings.GeneralSettings; -import ee.carlrobert.codegpt.settings.service.ServiceType; -import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings; -import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings; - -public class GeneralSettingsTest extends BasePlatformTestCase { - - public void testOpenAISettingsSync() { - var openAISettings = OpenAISettings.getCurrentState(); - openAISettings.setModel("gpt-3.5-turbo"); - var conversation = new Conversation(); - conversation.setModel("gpt-4"); - conversation.setClientCode("chat.completion"); - var settings = GeneralSettings.getInstance(); - - settings.sync(conversation); - - assertThat(settings.getState().getSelectedService()).isEqualTo(ServiceType.OPENAI); - assertThat(openAISettings.getModel()).isEqualTo("gpt-4"); - } - - public void testAzureSettingsSync() { - var settings = GeneralSettings.getInstance(); - var conversation = new Conversation(); - conversation.setModel("gpt-4"); - conversation.setClientCode("azure.chat.completion"); - - settings.sync(conversation); - - assertThat(settings.getState().getSelectedService()).isEqualTo(ServiceType.AZURE); - } - - public void testYouSettingsSync() { - var settings = GeneralSettings.getInstance(); - var conversation = new Conversation(); - conversation.setModel("YouCode"); - conversation.setClientCode("you.chat.completion"); - - settings.sync(conversation); - - assertThat(settings.getState().getSelectedService()).isEqualTo(ServiceType.YOU); - } - - public void testLlamaSettingsModelPathSync() { - var llamaSettings = LlamaSettings.getCurrentState(); - llamaSettings.setHuggingFaceModel(HuggingFaceModel.WIZARD_CODER_PYTHON_7B_Q3); - var conversation = new Conversation(); - conversation.setModel("TEST_LLAMA_MODEL_PATH"); - conversation.setClientCode("llama.chat.completion"); - var settings = GeneralSettings.getInstance(); - - settings.sync(conversation); - - assertThat(settings.getState().getSelectedService()).isEqualTo(ServiceType.LLAMA_CPP); - assertThat(llamaSettings.getCustomLlamaModelPath()).isEqualTo("TEST_LLAMA_MODEL_PATH"); - assertThat(llamaSettings.isUseCustomModel()).isTrue(); - } - - public void testLlamaSettingsHuggingFaceModelSync() { - var llamaSettings = LlamaSettings.getCurrentState(); - llamaSettings.setHuggingFaceModel(HuggingFaceModel.WIZARD_CODER_PYTHON_7B_Q3); - var conversation = new Conversation(); - conversation.setModel("CODE_LLAMA_7B_Q3"); - conversation.setClientCode("llama.chat.completion"); - var settings = GeneralSettings.getInstance(); - - settings.sync(conversation); - - assertThat(settings.getState().getSelectedService()).isEqualTo(ServiceType.LLAMA_CPP); - assertThat(llamaSettings.getHuggingFaceModel()).isEqualTo(CODE_LLAMA_7B_Q3); - assertThat(llamaSettings.isUseCustomModel()).isFalse(); - } -} diff --git a/src/test/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanelTest.java b/src/test/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanelTest.java deleted file mode 100644 index 4835894b..00000000 --- a/src/test/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanelTest.java +++ /dev/null @@ -1,415 +0,0 @@ -package ee.carlrobert.codegpt.toolwindow.chat; - -import static ee.carlrobert.codegpt.completions.CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT; -import static ee.carlrobert.codegpt.completions.CompletionRequestProvider.FIX_COMPILE_ERRORS_SYSTEM_PROMPT; -import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA; -import static ee.carlrobert.llm.client.util.JSONUtil.e; -import static ee.carlrobert.llm.client.util.JSONUtil.jsonArray; -import static ee.carlrobert.llm.client.util.JSONUtil.jsonMap; -import static ee.carlrobert.llm.client.util.JSONUtil.jsonMapResponse; -import static java.util.Objects.requireNonNull; -import static org.apache.http.HttpHeaders.AUTHORIZATION; -import static org.assertj.core.api.Assertions.assertThat; - -import ee.carlrobert.codegpt.CodeGPTKeys; -import ee.carlrobert.codegpt.EncodingManager; -import ee.carlrobert.codegpt.ReferencedFile; -import ee.carlrobert.codegpt.completions.ConversationType; -import ee.carlrobert.codegpt.completions.HuggingFaceModel; -import ee.carlrobert.codegpt.conversations.ConversationService; -import ee.carlrobert.codegpt.conversations.message.Message; -import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings; -import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings; -import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange; -import java.io.IOException; -import java.nio.file.Files; -import java.nio.file.Path; -import java.util.Base64; -import java.util.List; -import java.util.Map; -import testsupport.IntegrationTest; - -public class ChatToolWindowTabPanelTest extends IntegrationTest { - - public void testSendingOpenAIMessage() { - useOpenAIService(); - ConfigurationSettings.getCurrentState().setSystemPrompt(COMPLETION_SYSTEM_PROMPT); - var message = new Message("Hello!"); - var conversation = ConversationService.getInstance().startConversation(); - var panel = new ChatToolWindowTabPanel(getProject(), conversation); - expectOpenAI((StreamHttpExchange) request -> { - assertThat(request.getUri().getPath()).isEqualTo("/v1/chat/completions"); - assertThat(request.getMethod()).isEqualTo("POST"); - assertThat(request.getHeaders().get(AUTHORIZATION).get(0)).isEqualTo("Bearer TEST_API_KEY"); - assertThat(request.getBody()) - .extracting( - "model", - "messages") - .containsExactly( - "gpt-4", - List.of( - Map.of("role", "system", "content", COMPLETION_SYSTEM_PROMPT), - Map.of("role", "user", "content", "Hello!"))); - return List.of( - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))); - }); - - panel.sendMessage(message); - - waitExpecting(() -> { - var messages = conversation.getMessages(); - return !messages.isEmpty() && "Hello!".equals(messages.get(0).getResponse()); - }); - var encodingManager = EncodingManager.getInstance(); - assertThat(panel.getTokenDetails()).extracting( - "systemPromptTokens", - "conversationTokens", - "userPromptTokens", - "highlightedTokens") - .containsExactly( - encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT), - encodingManager.countTokens(message.getPrompt()), - 0, - 0); - assertThat(panel.getConversation()) - .isNotNull() - .extracting("id", "model", "clientCode", "discardTokenLimit") - .containsExactly( - conversation.getId(), - conversation.getModel(), - conversation.getClientCode(), - false); - var messages = panel.getConversation().getMessages(); - assertThat(messages.size()).isOne(); - assertThat(messages.get(0)) - .extracting("id", "prompt", "response") - .containsExactly(message.getId(), message.getPrompt(), message.getResponse()); - } - - public void testSendingOpenAIMessageWithReferencedContext() { - getProject().putUserData(CodeGPTKeys.SELECTED_FILES, List.of( - new ReferencedFile("TEST_FILE_NAME_1", "TEST_FILE_PATH_1", "TEST_FILE_CONTENT_1"), - new ReferencedFile("TEST_FILE_NAME_2", "TEST_FILE_PATH_2", "TEST_FILE_CONTENT_2"), - new ReferencedFile("TEST_FILE_NAME_3", "TEST_FILE_PATH_3", "TEST_FILE_CONTENT_3"))); - useOpenAIService(); - ConfigurationSettings.getCurrentState().setSystemPrompt(COMPLETION_SYSTEM_PROMPT); - var message = new Message("TEST_MESSAGE"); - message.setUserMessage("TEST_MESSAGE"); - message.setReferencedFilePaths( - List.of("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3")); - var conversation = ConversationService.getInstance().startConversation(); - var panel = new ChatToolWindowTabPanel(getProject(), conversation); - expectOpenAI((StreamHttpExchange) request -> { - assertThat(request.getUri().getPath()).isEqualTo("/v1/chat/completions"); - assertThat(request.getMethod()).isEqualTo("POST"); - assertThat(request.getHeaders().get(AUTHORIZATION).get(0)).isEqualTo("Bearer TEST_API_KEY"); - assertThat(request.getBody()) - .extracting( - "model", - "messages") - .containsExactly( - "gpt-4", - List.of( - Map.of("role", "system", "content", COMPLETION_SYSTEM_PROMPT), - Map.of("role", "user", "content", - """ - Use the following context to answer question at the end: - - File Path: TEST_FILE_PATH_1 - File Content: - ```TEST_FILE_NAME_1 - TEST_FILE_CONTENT_1 - ``` - - File Path: TEST_FILE_PATH_2 - File Content: - ```TEST_FILE_NAME_2 - TEST_FILE_CONTENT_2 - ``` - - File Path: TEST_FILE_PATH_3 - File Content: - ```TEST_FILE_NAME_3 - TEST_FILE_CONTENT_3 - ``` - - Question: TEST_MESSAGE"""))); - return List.of( - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))); - }); - - panel.sendMessage(message); - - waitExpecting(() -> { - var messages = conversation.getMessages(); - return !messages.isEmpty() && "Hello!".equals(messages.get(0).getResponse()); - }); - var encodingManager = EncodingManager.getInstance(); - assertThat(panel.getTokenDetails()).extracting( - "systemPromptTokens", - "conversationTokens", - "userPromptTokens", - "highlightedTokens") - .containsExactly( - encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT), - encodingManager.countTokens(message.getPrompt()), - 0, - 0); - assertThat(panel.getConversation()) - .isNotNull() - .extracting("id", "model", "clientCode", "discardTokenLimit") - .containsExactly( - conversation.getId(), - conversation.getModel(), - conversation.getClientCode(), - false); - var messages = panel.getConversation().getMessages(); - assertThat(messages.size()).isOne(); - assertThat(messages.get(0)) - .extracting("id", "prompt", "response", "referencedFilePaths") - .containsExactly( - message.getId(), - message.getPrompt(), - message.getResponse(), - List.of("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3")); - } - - public void testSendingOpenAIMessageWithImage() { - var testImagePath = requireNonNull(getClass().getResource("/images/test-image.png")).getPath(); - getProject().putUserData(CodeGPTKeys.IMAGE_ATTACHMENT_FILE_PATH, testImagePath); - useOpenAIService("gpt-4-vision-preview"); - ConfigurationSettings.getCurrentState().setSystemPrompt(COMPLETION_SYSTEM_PROMPT); - var message = new Message("TEST_MESSAGE"); - var conversation = ConversationService.getInstance().startConversation(); - var panel = new ChatToolWindowTabPanel(getProject(), conversation); - expectOpenAI((StreamHttpExchange) request -> { - assertThat(request.getUri().getPath()).isEqualTo("/v1/chat/completions"); - assertThat(request.getMethod()).isEqualTo("POST"); - assertThat(request.getHeaders().get(AUTHORIZATION).get(0)).isEqualTo("Bearer TEST_API_KEY"); - try { - var testImageUrl = "data:image/png;base64," - + Base64.getEncoder().encodeToString(Files.readAllBytes(Path.of(testImagePath))); - assertThat(request.getBody()) - .extracting("model", "messages") - .containsExactly( - "gpt-4-vision-preview", - List.of( - Map.of("role", "system", "content", COMPLETION_SYSTEM_PROMPT), - Map.of("role", "user", "content", List.of( - Map.of( - "type", "image_url", - "image_url", Map.of("url", testImageUrl)), - Map.of("type", "text", "text", "TEST_MESSAGE") - )))); - } catch (IOException e) { - throw new RuntimeException(e); - } - return List.of( - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))); - }); - - panel.sendMessage(message); - - waitExpecting(() -> { - var messages = conversation.getMessages(); - return !messages.isEmpty() && "Hello!".equals(messages.get(0).getResponse()); - }); - var encodingManager = EncodingManager.getInstance(); - assertThat(panel.getTokenDetails()).extracting( - "systemPromptTokens", - "conversationTokens", - "userPromptTokens", - "highlightedTokens") - .containsExactly( - encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT), - encodingManager.countTokens(message.getPrompt()), - 0, - 0); - assertThat(panel.getConversation()) - .isNotNull() - .extracting("id", "model", "clientCode", "discardTokenLimit") - .containsExactly( - conversation.getId(), - conversation.getModel(), - conversation.getClientCode(), - false); - var messages = panel.getConversation().getMessages(); - assertThat(messages.size()).isOne(); - assertThat(messages.get(0)) - .extracting("id", "prompt", "response", "imageFilePath") - .containsExactly( - message.getId(), - message.getPrompt(), - message.getResponse(), - message.getImageFilePath()); - } - - public void testFixCompileErrorsWithOpenAIService() { - getProject().putUserData(CodeGPTKeys.SELECTED_FILES, List.of( - new ReferencedFile("TEST_FILE_NAME_1", "TEST_FILE_PATH_1", "TEST_FILE_CONTENT_1"), - new ReferencedFile("TEST_FILE_NAME_2", "TEST_FILE_PATH_2", "TEST_FILE_CONTENT_2"), - new ReferencedFile("TEST_FILE_NAME_3", "TEST_FILE_PATH_3", "TEST_FILE_CONTENT_3"))); - useOpenAIService(); - ConfigurationSettings.getCurrentState().setSystemPrompt(COMPLETION_SYSTEM_PROMPT); - var message = new Message("TEST_MESSAGE"); - message.setUserMessage("TEST_MESSAGE"); - message.setReferencedFilePaths( - List.of("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3")); - var conversation = ConversationService.getInstance().startConversation(); - var panel = new ChatToolWindowTabPanel(getProject(), conversation); - expectOpenAI((StreamHttpExchange) request -> { - assertThat(request.getUri().getPath()).isEqualTo("/v1/chat/completions"); - assertThat(request.getMethod()).isEqualTo("POST"); - assertThat(request.getHeaders().get(AUTHORIZATION).get(0)).isEqualTo("Bearer TEST_API_KEY"); - assertThat(request.getBody()) - .extracting( - "model", - "messages") - .containsExactly( - "gpt-4", - List.of( - Map.of("role", "system", "content", FIX_COMPILE_ERRORS_SYSTEM_PROMPT), - Map.of("role", "user", "content", - """ - Use the following context to answer question at the end: - - File Path: TEST_FILE_PATH_1 - File Content: - ```TEST_FILE_NAME_1 - TEST_FILE_CONTENT_1 - ``` - - File Path: TEST_FILE_PATH_2 - File Content: - ```TEST_FILE_NAME_2 - TEST_FILE_CONTENT_2 - ``` - - File Path: TEST_FILE_PATH_3 - File Content: - ```TEST_FILE_NAME_3 - TEST_FILE_CONTENT_3 - ``` - - Question: TEST_MESSAGE"""))); - return List.of( - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))), - jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))); - }); - - panel.sendMessage(message, ConversationType.FIX_COMPILE_ERRORS); - - waitExpecting(() -> { - var messages = conversation.getMessages(); - return !messages.isEmpty() && "Hello!".equals(messages.get(0).getResponse()); - }); - var encodingManager = EncodingManager.getInstance(); - assertThat(panel.getTokenDetails()).extracting( - "systemPromptTokens", - "conversationTokens", - "userPromptTokens", - "highlightedTokens") - .containsExactly( - encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT), - encodingManager.countTokens(message.getPrompt()), - 0, - 0); - assertThat(panel.getConversation()) - .isNotNull() - .extracting("id", "model", "clientCode", "discardTokenLimit") - .containsExactly( - conversation.getId(), - conversation.getModel(), - conversation.getClientCode(), - false); - var messages = panel.getConversation().getMessages(); - assertThat(messages.size()).isOne(); - assertThat(messages.get(0)) - .extracting("id", "prompt", "response", "referencedFilePaths") - .containsExactly( - message.getId(), - message.getPrompt(), - message.getResponse(), - List.of("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3")); - } - - public void testSendingLlamaMessage() { - useLlamaService(); - var configurationState = ConfigurationSettings.getCurrentState(); - configurationState.setSystemPrompt(COMPLETION_SYSTEM_PROMPT); - configurationState.setMaxTokens(1000); - configurationState.setTemperature(0.1); - var llamaSettings = LlamaSettings.getCurrentState(); - llamaSettings.setUseCustomModel(false); - llamaSettings.setHuggingFaceModel(HuggingFaceModel.CODE_LLAMA_7B_Q4); - llamaSettings.setTopK(30); - llamaSettings.setTopP(0.8); - llamaSettings.setMinP(0.03); - llamaSettings.setRepeatPenalty(1.3); - var message = new Message("TEST_PROMPT"); - var conversation = ConversationService.getInstance().startConversation(); - var panel = new ChatToolWindowTabPanel(getProject(), conversation); - expectLlama((StreamHttpExchange) request -> { - assertThat(request.getUri().getPath()).isEqualTo("/completion"); - assertThat(request.getBody()) - .extracting( - "prompt", - "n_predict", - "stream", - "temperature", - "top_k", - "top_p", - "min_p", - "repeat_penalty") - .containsExactly( - LLAMA.buildPrompt( - COMPLETION_SYSTEM_PROMPT, - "TEST_PROMPT", - conversation.getMessages()), - configurationState.getMaxTokens(), - true, - configurationState.getTemperature(), - llamaSettings.getTopK(), - llamaSettings.getTopP(), - llamaSettings.getMinP(), - llamaSettings.getRepeatPenalty()); - return List.of( - jsonMapResponse("content", "Hel"), - jsonMapResponse("content", "lo!"), - jsonMapResponse( - e("content", ""), - e("stop", true))); - }); - - panel.sendMessage(message, ConversationType.DEFAULT); - - waitExpecting(() -> { - var messages = conversation.getMessages(); - return !messages.isEmpty() && "Hello!".equals(messages.get(0).getResponse()); - }); - assertThat(panel.getConversation()) - .isNotNull() - .extracting("id", "model", "clientCode", "discardTokenLimit") - .containsExactly( - conversation.getId(), - conversation.getModel(), - conversation.getClientCode(), - false); - var messages = panel.getConversation().getMessages(); - assertThat(messages.size()).isOne(); - assertThat(messages.get(0)) - .extracting("id", "prompt", "response") - .containsExactly(message.getId(), message.getPrompt(), message.getResponse()); - } -} diff --git a/src/test/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabbedPaneTest.java b/src/test/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabbedPaneTest.java deleted file mode 100644 index 0746f538..00000000 --- a/src/test/java/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabbedPaneTest.java +++ /dev/null @@ -1,50 +0,0 @@ -package ee.carlrobert.codegpt.toolwindow.chat; - -import static org.assertj.core.api.Assertions.assertThat; - -import com.intellij.openapi.util.Disposer; -import com.intellij.testFramework.fixtures.BasePlatformTestCase; -import ee.carlrobert.codegpt.conversations.ConversationService; -import ee.carlrobert.codegpt.conversations.message.Message; - -public class ChatToolWindowTabbedPaneTest extends BasePlatformTestCase { - - public void testClearAllTabs() { - var tabbedPane = new ChatToolWindowTabbedPane(Disposer.newDisposable()); - tabbedPane.addNewTab(createNewTabPanel()); - - tabbedPane.clearAll(); - - assertThat(tabbedPane.getActiveTabMapping()).isEmpty(); - } - - - public void testAddingNewTabs() { - var tabbedPane = new ChatToolWindowTabbedPane(Disposer.newDisposable()); - - tabbedPane.addNewTab(createNewTabPanel()); - tabbedPane.addNewTab(createNewTabPanel()); - tabbedPane.addNewTab(createNewTabPanel()); - - assertThat(tabbedPane.getActiveTabMapping().keySet()) - .containsExactly("Chat 1", "Chat 2", "Chat 3"); - } - - public void testResetCurrentlyActiveTabPanel() { - var tabbedPane = new ChatToolWindowTabbedPane(Disposer.newDisposable()); - var conversation = ConversationService.getInstance().startConversation(); - conversation.addMessage(new Message("TEST_PROMPT", "TEST_RESPONSE")); - tabbedPane.addNewTab(new ChatToolWindowTabPanel(getProject(), conversation)); - - tabbedPane.resetCurrentlyActiveTabPanel(getProject()); - - var tabPanel = tabbedPane.getActiveTabMapping().get("Chat 1"); - assertThat(tabPanel.getConversation().getMessages()).isEmpty(); - } - - private ChatToolWindowTabPanel createNewTabPanel() { - return new ChatToolWindowTabPanel( - getProject(), - ConversationService.getInstance().startConversation()); - } -} diff --git a/src/test/java/testsupport/IntegrationTest.java b/src/test/java/testsupport/IntegrationTest.java deleted file mode 100644 index f92be8e2..00000000 --- a/src/test/java/testsupport/IntegrationTest.java +++ /dev/null @@ -1,34 +0,0 @@ -package testsupport; - -import com.intellij.openapi.util.Key; -import com.intellij.testFramework.fixtures.BasePlatformTestCase; -import ee.carlrobert.codegpt.CodeGPTKeys; -import ee.carlrobert.llm.client.mixin.ExternalServiceTestMixin; -import java.util.Collections; -import testsupport.mixin.ShortcutsTestMixin; - -public class IntegrationTest extends BasePlatformTestCase implements - ExternalServiceTestMixin, - ShortcutsTestMixin { - - static { - ExternalServiceTestMixin.init(); - } - - @Override - protected void tearDown() throws Exception { - ExternalServiceTestMixin.clearAll(); - clearKeys(); - super.tearDown(); - } - - private void clearKeys() { - putUserData(CodeGPTKeys.SELECTED_FILES, Collections.emptyList()); - putUserData(CodeGPTKeys.PREVIOUS_INLAY_TEXT, ""); - putUserData(CodeGPTKeys.IMAGE_ATTACHMENT_FILE_PATH, ""); - } - - private void putUserData(Key key, T value) { - getProject().putUserData(key, value); - } -} diff --git a/src/test/java/testsupport/mixin/ShortcutsTestMixin.java b/src/test/java/testsupport/mixin/ShortcutsTestMixin.java deleted file mode 100644 index 9c7d4aa2..00000000 --- a/src/test/java/testsupport/mixin/ShortcutsTestMixin.java +++ /dev/null @@ -1,51 +0,0 @@ -package testsupport.mixin; - -import static ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.AZURE_OPENAI_API_KEY; -import static ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.OPENAI_API_KEY; - -import com.intellij.testFramework.PlatformTestUtil; -import ee.carlrobert.codegpt.credentials.CredentialsStore; -import ee.carlrobert.codegpt.settings.GeneralSettings; -import ee.carlrobert.codegpt.settings.service.ServiceType; -import ee.carlrobert.codegpt.settings.service.azure.AzureSettings; -import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings; -import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings; -import java.util.function.BooleanSupplier; - -public interface ShortcutsTestMixin { - - default void useOpenAIService() { - useOpenAIService("gpt-4"); - } - - default void useOpenAIService(String model) { - GeneralSettings.getCurrentState().setSelectedService(ServiceType.OPENAI); - CredentialsStore.INSTANCE.setCredential(OPENAI_API_KEY, "TEST_API_KEY"); - OpenAISettings.getCurrentState().setModel(model); - } - - default void useAzureService() { - GeneralSettings.getCurrentState().setSelectedService(ServiceType.AZURE); - CredentialsStore.INSTANCE.setCredential(AZURE_OPENAI_API_KEY, "TEST_API_KEY"); - var azureSettings = AzureSettings.getCurrentState(); - azureSettings.setResourceName("TEST_RESOURCE_NAME"); - azureSettings.setApiVersion("TEST_API_VERSION"); - azureSettings.setDeploymentId("TEST_DEPLOYMENT_ID"); - } - - default void useYouService() { - GeneralSettings.getCurrentState().setSelectedService(ServiceType.YOU); - } - - default void useLlamaService() { - GeneralSettings.getCurrentState().setSelectedService(ServiceType.LLAMA_CPP); - LlamaSettings.getCurrentState().setServerPort(null); - } - - default void waitExpecting(BooleanSupplier condition) { - PlatformTestUtil.waitWithEventsDispatching( - "Waiting for message response timed out or did not meet expected conditions", - condition, - 5); - } -} diff --git a/src/test/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionServiceTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionServiceTest.kt new file mode 100644 index 00000000..d92c00b9 --- /dev/null +++ b/src/test/kotlin/ee/carlrobert/codegpt/codecompletions/CodeCompletionServiceTest.kt @@ -0,0 +1,50 @@ +package ee.carlrobert.codegpt.codecompletions + +import com.intellij.openapi.editor.VisualPosition +import ee.carlrobert.codegpt.CodeGPTKeys.PREVIOUS_INLAY_TEXT +import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings +import ee.carlrobert.codegpt.util.file.FileUtil.getResourceContent +import ee.carlrobert.llm.client.http.RequestEntity +import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange +import ee.carlrobert.llm.client.util.JSONUtil.e +import ee.carlrobert.llm.client.util.JSONUtil.jsonMapResponse +import org.assertj.core.api.Assertions.assertThat +import testsupport.IntegrationTest + +class CodeCompletionServiceTest : IntegrationTest() { + + private val cursorPosition = VisualPosition(3, 0) + + fun testFetchCodeCompletionLlama() { + useLlamaService() + LlamaSettings.getCurrentState().isCodeCompletionsEnabled = true + myFixture.configureByText( + "CompletionTest.java", + getResourceContent("/codecompletions/code-completion-file.txt") + ) + myFixture.editor.caretModel.moveToVisualPosition(cursorPosition) + val expectedCompletion = "TEST_OUTPUT" + val prefix = """ + ${"z".repeat(245)} + [INPUT] + c + """.trimIndent() // 128 tokens + val suffix = """ + + [\INPUT] + ${"z".repeat(247)} + """.trimIndent() // 128 tokens + expectLlama(StreamHttpExchange { request: RequestEntity -> + assertThat(request.uri.path).isEqualTo("/completion") + assertThat(request.method).isEqualTo("POST") + assertThat(request.body) + .extracting("prompt") + .isEqualTo(InfillPromptTemplate.LLAMA.buildPrompt(prefix, suffix)) + listOf(jsonMapResponse(e("content", expectedCompletion), e("stop", true))) + }) + + myFixture.type('c') + + waitExpecting { "TEST_OUTPUT" == PREVIOUS_INLAY_TEXT[myFixture.editor] } + } +} diff --git a/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt new file mode 100644 index 00000000..97613aeb --- /dev/null +++ b/src/test/kotlin/ee/carlrobert/codegpt/completions/CompletionRequestProviderTest.kt @@ -0,0 +1,152 @@ +package ee.carlrobert.codegpt.completions + +import ee.carlrobert.codegpt.conversations.ConversationService +import ee.carlrobert.codegpt.conversations.message.Message +import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey +import ee.carlrobert.codegpt.credentials.CredentialsStore.setCredential +import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings +import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel +import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.groups.Tuple +import testsupport.IntegrationTest + +class CompletionRequestProviderTest : IntegrationTest() { + + fun testChatCompletionRequestWithSystemPromptOverride() { + setCredential(CredentialKey.OPENAI_API_KEY, "TEST_API_KEY") + ConfigurationSettings.getCurrentState().systemPrompt = "TEST_SYSTEM_PROMPT" + val conversation = ConversationService.getInstance().startConversation() + val firstMessage = createDummyMessage(500) + val secondMessage = createDummyMessage(250) + conversation.addMessage(firstMessage) + conversation.addMessage(secondMessage) + + val request = CompletionRequestProvider(conversation) + .buildOpenAIChatCompletionRequest( + OpenAIChatCompletionModel.GPT_3_5.code, + CallParameters( + conversation, + ConversationType.DEFAULT, + Message("TEST_CHAT_COMPLETION_PROMPT"), + false)) + + assertThat(request.messages) + .extracting("role", "content") + .containsExactly( + Tuple.tuple("system", "TEST_SYSTEM_PROMPT"), + Tuple.tuple("user", "TEST_PROMPT"), + Tuple.tuple("assistant", firstMessage.response), + Tuple.tuple("user", "TEST_PROMPT"), + Tuple.tuple("assistant", secondMessage.response), + Tuple.tuple("user", "TEST_CHAT_COMPLETION_PROMPT")) + } + + fun testChatCompletionRequestWithoutSystemPromptOverride() { + val conversation = ConversationService.getInstance().startConversation() + val firstMessage = createDummyMessage(500) + val secondMessage = createDummyMessage(250) + conversation.addMessage(firstMessage) + conversation.addMessage(secondMessage) + + val request = CompletionRequestProvider(conversation) + .buildOpenAIChatCompletionRequest( + OpenAIChatCompletionModel.GPT_3_5.code, + CallParameters( + conversation, + ConversationType.DEFAULT, + Message("TEST_CHAT_COMPLETION_PROMPT"), + false)) + + assertThat(request.messages) + .extracting("role", "content") + .containsExactly( + Tuple.tuple("system", CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT), + Tuple.tuple("user", "TEST_PROMPT"), + Tuple.tuple("assistant", firstMessage.response), + Tuple.tuple("user", "TEST_PROMPT"), + Tuple.tuple("assistant", secondMessage.response), + Tuple.tuple("user", "TEST_CHAT_COMPLETION_PROMPT")) + } + + fun testChatCompletionRequestRetry() { + ConfigurationSettings.getCurrentState().systemPrompt = CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT + val conversation = ConversationService.getInstance().startConversation() + val firstMessage = createDummyMessage("FIRST_TEST_PROMPT", 500) + val secondMessage = createDummyMessage("SECOND_TEST_PROMPT", 250) + conversation.addMessage(firstMessage) + conversation.addMessage(secondMessage) + + val request = CompletionRequestProvider(conversation) + .buildOpenAIChatCompletionRequest( + OpenAIChatCompletionModel.GPT_3_5.code, + CallParameters( + conversation, + ConversationType.DEFAULT, + secondMessage, + true)) + + assertThat(request.messages) + .extracting("role", "content") + .containsExactly( + Tuple.tuple("system", CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT), + Tuple.tuple("user", "FIRST_TEST_PROMPT"), + Tuple.tuple("assistant", firstMessage.response), + Tuple.tuple("user", "SECOND_TEST_PROMPT")) + } + + fun testReducedChatCompletionRequest() { + val conversation = ConversationService.getInstance().startConversation() + conversation.addMessage(createDummyMessage(50)) + conversation.addMessage(createDummyMessage(100)) + conversation.addMessage(createDummyMessage(150)) + conversation.addMessage(createDummyMessage(1000)) + val remainingMessage = createDummyMessage(2000) + conversation.addMessage(remainingMessage) + conversation.discardTokenLimits() + + val request = CompletionRequestProvider(conversation) + .buildOpenAIChatCompletionRequest( + OpenAIChatCompletionModel.GPT_3_5.code, + CallParameters( + conversation, + ConversationType.DEFAULT, + Message("TEST_CHAT_COMPLETION_PROMPT"), + false)) + + assertThat(request.messages) + .extracting("role", "content") + .containsExactly( + Tuple.tuple("system", CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT), + Tuple.tuple("user", "TEST_PROMPT"), + Tuple.tuple("assistant", remainingMessage.response), + Tuple.tuple("user", "TEST_CHAT_COMPLETION_PROMPT")) + } + + fun testTotalUsageExceededException() { + val conversation = ConversationService.getInstance().startConversation() + conversation.addMessage(createDummyMessage(1500)) + conversation.addMessage(createDummyMessage(1500)) + conversation.addMessage(createDummyMessage(1500)) + + assertThrows(TotalUsageExceededException::class.java) { + CompletionRequestProvider(conversation) + .buildOpenAIChatCompletionRequest( + OpenAIChatCompletionModel.GPT_3_5.code, + CallParameters( + conversation, + ConversationType.DEFAULT, + createDummyMessage(100), + false)) } + } + + private fun createDummyMessage(tokenSize: Int): Message { + return createDummyMessage("TEST_PROMPT", tokenSize) + } + + private fun createDummyMessage(prompt: String, tokenSize: Int): Message { + val message = Message(prompt) + // 'zz' = 1 token, prompt = 6 tokens, 7 tokens per message (GPT-3), + message.response = "zz".repeat((tokenSize) - 6 - 7) + return message + } +} diff --git a/src/test/kotlin/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.kt new file mode 100644 index 00000000..181f6ee5 --- /dev/null +++ b/src/test/kotlin/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.kt @@ -0,0 +1,179 @@ +package ee.carlrobert.codegpt.completions + +import ee.carlrobert.codegpt.CodeGPTPlugin +import ee.carlrobert.codegpt.completions.CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT +import ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA +import ee.carlrobert.codegpt.conversations.ConversationService +import ee.carlrobert.codegpt.conversations.message.Message +import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings +import ee.carlrobert.llm.client.http.RequestEntity +import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange +import ee.carlrobert.llm.client.util.JSONUtil.e +import ee.carlrobert.llm.client.util.JSONUtil.jsonArray +import ee.carlrobert.llm.client.util.JSONUtil.jsonMap +import ee.carlrobert.llm.client.util.JSONUtil.jsonMapResponse +import org.apache.http.HttpHeaders +import org.assertj.core.api.Assertions.assertThat +import testsupport.IntegrationTest + +class DefaultCompletionRequestHandlerTest : IntegrationTest() { + + fun testOpenAIChatCompletionCall() { + useOpenAIService() + val message = Message("TEST_PROMPT") + val conversation = ConversationService.getInstance().startConversation() + val requestHandler = CompletionRequestHandler(getRequestEventListener(message)) + expectOpenAI(StreamHttpExchange { request: RequestEntity -> + assertThat(request.uri.path).isEqualTo("/v1/chat/completions") + assertThat(request.method).isEqualTo("POST") + assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY") + assertThat(request.body) + .extracting( + "model", + "messages") + .containsExactly( + "gpt-4", + listOf( + mapOf("role" to "system", "content" to COMPLETION_SYSTEM_PROMPT), + mapOf("role" to "user", "content" to "TEST_PROMPT"))) + listOf( + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))) + }) + + requestHandler.call(CallParameters(conversation, ConversationType.DEFAULT, message, false)) + + waitExpecting { "Hello!" == message.response } + } + + fun testAzureChatCompletionCall() { + useAzureService() + val conversationService = ConversationService.getInstance() + val prevMessage = Message("TEST_PREV_PROMPT") + prevMessage.response = "TEST_PREV_RESPONSE" + val conversation = conversationService.startConversation() + conversation.addMessage(prevMessage) + conversationService.saveConversation(conversation) + expectAzure(StreamHttpExchange { request: RequestEntity -> + assertThat(request.uri.path).isEqualTo( + "/openai/deployments/TEST_DEPLOYMENT_ID/chat/completions") + assertThat(request.uri.query).isEqualTo("api-version=TEST_API_VERSION") + assertThat(request.headers["Api-key"]!![0]).isEqualTo("TEST_API_KEY") + assertThat(request.headers["X-llm-application-tag"]!![0]).isEqualTo("codegpt") + assertThat(request.body) + .extracting("messages") + .isEqualTo( + listOf( + mapOf("role" to "system", "content" to COMPLETION_SYSTEM_PROMPT), + mapOf("role" to "user", "content" to "TEST_PREV_PROMPT"), + mapOf("role" to "assistant", "content" to "TEST_PREV_RESPONSE"), + mapOf("role" to "user", "content" to "TEST_PROMPT"))) + listOf( + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))) + }) + val message = Message("TEST_PROMPT") + val requestHandler = CompletionRequestHandler(getRequestEventListener(message)) + + requestHandler.call(CallParameters(conversation, ConversationType.DEFAULT, message, false)) + + waitExpecting { "Hello!" == message.response } + } + + fun testYouChatCompletionCall() { + useYouService() + val message = Message("TEST_PROMPT") + val conversation = ConversationService.getInstance().startConversation() + conversation.addMessage(Message("Ping", "Pong")) + val requestHandler = CompletionRequestHandler(getRequestEventListener(message)) + expectYou(StreamHttpExchange { request: RequestEntity -> + assertThat(request.uri.path).isEqualTo("/api/streamingSearch") + assertThat(request.method).isEqualTo("GET") + assertThat(request.uri.path).isEqualTo("/api/streamingSearch") + assertThat(request.uri.query).isEqualTo( + "q=TEST_PROMPT&" + + "page=1&" + + "cfr=CodeGPT&" + + "count=10&" + + "safeSearch=WebPages,Translations,TimeZone,Computation,RelatedSearches&" + + "domain=youchat&" + + "selectedChatMode=default&" + + "chat=[{\"question\":\"Ping\",\"answer\":\"Pong\"}]&" + + "utm_source=ide&" + + "utm_medium=jetbrains&" + + "utm_campaign=" + CodeGPTPlugin.getVersion() + "&" + + "utm_content=CodeGPT") + assertThat(request.headers) + .flatExtracting("Accept", "Connection", "User-agent", "Cookie") + .containsExactly( + "text/event-stream", + "Keep-Alive", + "youide CodeGPT", + "safesearch_guest=Moderate; " + + "youpro_subscription=true; " + + "you_subscription=free; " + + "stytch_session=; " + + "ydc_stytch_session=; " + + "stytch_session_jwt=; " + + "ydc_stytch_session_jwt=; " + + "eg4=false; " + + "__cf_bm=aN2b3pQMH8XADeMB7bg9s1bJ_bfXBcCHophfOGRg6g0-1693601599-0-" + + "AWIt5Mr4Y3xQI4mIJ1lSf4+vijWKDobrty8OopDeBxY+NABe0MRFidF3dCUoWjRt8" + + "SVMvBZPI3zkOgcRs7Mz3yazd7f7c58HwW5Xg9jdBjNg;") + listOf( + jsonMapResponse("youChatToken", "Hel"), + jsonMapResponse("youChatToken", "lo"), + jsonMapResponse("youChatToken", "!")) + }) + + requestHandler.call(CallParameters(conversation, ConversationType.DEFAULT, message, false)) + + waitExpecting { "Hello!" == message.response } + } + + fun testLlamaChatCompletionCall() { + useLlamaService() + ConfigurationSettings.getCurrentState().maxTokens = 99 + val message = Message("TEST_PROMPT") + val conversation = ConversationService.getInstance().startConversation() + conversation.addMessage(Message("Ping", "Pong")) + val requestHandler = CompletionRequestHandler(getRequestEventListener(message)) + expectLlama(StreamHttpExchange { request: RequestEntity -> + assertThat(request.uri.path).isEqualTo("/completion") + assertThat(request.body) + .extracting( + "prompt", + "n_predict", + "stream") + .containsExactly( + LLAMA.buildPrompt( + COMPLETION_SYSTEM_PROMPT, + "TEST_PROMPT", + conversation.messages), + 99, + true) + listOf( + jsonMapResponse("content", "Hel"), + jsonMapResponse("content", "lo!"), + jsonMapResponse( + e("content", ""), + e("stop", true))) + }) + + requestHandler.call(CallParameters(conversation, ConversationType.DEFAULT, message, false)) + + waitExpecting { "Hello!" == message.response } + } + + private fun getRequestEventListener(message: Message): CompletionResponseEventListener { + return object : CompletionResponseEventListener { + override fun handleCompleted(fullMessage: String, callParameters: CallParameters) { + message.response = fullMessage + } + } + } +} diff --git a/src/test/kotlin/ee/carlrobert/codegpt/completions/PromptTemplateTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/completions/PromptTemplateTest.kt new file mode 100644 index 00000000..8c7ff407 --- /dev/null +++ b/src/test/kotlin/ee/carlrobert/codegpt/completions/PromptTemplateTest.kt @@ -0,0 +1,149 @@ +package ee.carlrobert.codegpt.completions + +import ee.carlrobert.codegpt.completions.llama.PromptTemplate.ALPACA +import ee.carlrobert.codegpt.completions.llama.PromptTemplate.CHAT_ML +import ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA +import ee.carlrobert.codegpt.completions.llama.PromptTemplate.TORA +import ee.carlrobert.codegpt.conversations.message.Message +import org.assertj.core.api.Assertions.assertThat +import org.junit.Test + +class PromptTemplateTest { + + @Test + fun shouldBuildLlamaPromptWithHistory() { + val prompt = LLAMA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY) + + assertThat(prompt).isEqualTo(""" + <>TEST_SYSTEM_PROMPT<> + [INST]TEST_PREV_PROMPT_1[/INST] + TEST_PREV_RESPONSE_1 + [INST]TEST_PREV_PROMPT_2[/INST] + TEST_PREV_RESPONSE_2 + [INST]TEST_USER_PROMPT[/INST] + """.trimIndent()) + } + + @Test + fun shouldBuildLlamaPromptWithoutHistory() { + val prompt = LLAMA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, listOf()) + + assertThat(prompt).isEqualTo(""" + <>TEST_SYSTEM_PROMPT<> + [INST]TEST_USER_PROMPT[/INST] + """.trimIndent()) + } + + @Test + fun shouldBuildAlpacaPromptWithHistory() { + val prompt = ALPACA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY) + + assertThat(prompt).isEqualTo(""" + Below is an instruction that describes a task. Write a response that appropriately completes the request. + + ### Instruction + TEST_PREV_PROMPT_1 + + ### Response: + TEST_PREV_RESPONSE_1 + + ### Instruction + TEST_PREV_PROMPT_2 + + ### Response: + TEST_PREV_RESPONSE_2 + + ### Instruction + TEST_USER_PROMPT + + ### Response: + + """.trimIndent()) + } + + @Test + fun shouldBuildAlpacaPromptWithoutHistory() { + val prompt = ALPACA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, listOf()) + + assertThat(prompt).isEqualTo(""" + Below is an instruction that describes a task. Write a response that appropriately completes the request. + + ### Instruction + TEST_USER_PROMPT + + ### Response: + + """.trimIndent()) + } + + @Test + fun shouldBuildChatMLPromptWithHistory() { + val prompt = CHAT_ML.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY) + + assertThat(prompt).isEqualTo(""" + <|im_start|>system + TEST_SYSTEM_PROMPT<|im_end|> + <|im_start|>user + TEST_PREV_PROMPT_1<|im_end|> + <|im_start|>assistant + TEST_PREV_RESPONSE_1<|im_end|> + <|im_start|>user + TEST_PREV_PROMPT_2<|im_end|> + <|im_start|>assistant + TEST_PREV_RESPONSE_2<|im_end|> + <|im_start|>user + TEST_USER_PROMPT<|im_end|> + """.trimIndent()) + } + + @Test + fun shouldBuildChatMLPromptWithoutHistory() { + val prompt = CHAT_ML.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, listOf()) + + assertThat(prompt).isEqualTo(""" + <|im_start|>system + TEST_SYSTEM_PROMPT<|im_end|> + <|im_start|>user + TEST_USER_PROMPT<|im_end|> + """.trimIndent()) + } + + @Test + fun shouldBuildToRAPromptWithHistory() { + val prompt = TORA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY) + + assertThat(prompt).isEqualTo(""" + <|user|> + TEST_PREV_PROMPT_1 + <|assistant|> + TEST_PREV_RESPONSE_1 + <|user|> + TEST_PREV_PROMPT_2 + <|assistant|> + TEST_PREV_RESPONSE_2 + <|user|> + TEST_USER_PROMPT + <|assistant|> + """.trimIndent()) + } + + @Test + fun shouldBuildToRAPromptWithoutHistory() { + val prompt = TORA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, listOf()) + + assertThat(prompt).isEqualTo(""" + <|user|> + TEST_USER_PROMPT + <|assistant|> + """.trimIndent()) + } + + companion object { + private const val SYSTEM_PROMPT = "TEST_SYSTEM_PROMPT" + private const val USER_PROMPT = "TEST_USER_PROMPT" + private val HISTORY: List = listOf( + Message("TEST_PREV_PROMPT_1", "TEST_PREV_RESPONSE_1"), + Message("TEST_PREV_PROMPT_2", "TEST_PREV_RESPONSE_2") + ) + } +} diff --git a/src/test/kotlin/ee/carlrobert/codegpt/conversations/ConversationsStateTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/conversations/ConversationsStateTest.kt new file mode 100644 index 00000000..4866a7b6 --- /dev/null +++ b/src/test/kotlin/ee/carlrobert/codegpt/conversations/ConversationsStateTest.kt @@ -0,0 +1,89 @@ +package ee.carlrobert.codegpt.conversations + +import com.intellij.testFramework.fixtures.BasePlatformTestCase +import ee.carlrobert.codegpt.conversations.message.Message +import ee.carlrobert.codegpt.settings.GeneralSettings +import ee.carlrobert.codegpt.settings.service.ServiceType +import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings +import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel +import org.assertj.core.api.Assertions.assertThat + +class ConversationsStateTest : BasePlatformTestCase() { + + fun testStartNewDefaultConversation() { + GeneralSettings.getCurrentState().selectedService = ServiceType.OPENAI + OpenAISettings.getCurrentState().model = OpenAIChatCompletionModel.GPT_3_5.code + + val conversation = ConversationService.getInstance().startConversation() + + assertThat(conversation).isEqualTo(ConversationsState.getCurrentConversation()) + assertThat(conversation) + .extracting("clientCode", "model") + .containsExactly("chat.completion", "gpt-3.5-turbo") + } + + fun testSaveConversation() { + val service = ConversationService.getInstance() + val conversation = service.createConversation("chat.completion") + service.addConversation(conversation) + val message = Message("TEST_PROMPT") + message.response = "TEST_RESPONSE" + conversation.addMessage(message) + + service.saveConversation(conversation) + + val currentConversation = ConversationsState.getCurrentConversation() + assertThat(currentConversation).isNotNull() + assertThat(currentConversation!!.messages) + .flatExtracting("prompt", "response") + .containsExactly("TEST_PROMPT", "TEST_RESPONSE") + } + + fun testGetPreviousConversation() { + val service = ConversationService.getInstance() + val firstConversation = service.startConversation() + service.startConversation() + + val previousConversation = service.previousConversation + + assertThat(previousConversation.isPresent).isTrue() + assertThat(previousConversation.get()).isEqualTo(firstConversation) + } + + fun testGetNextConversation() { + val service = ConversationService.getInstance() + val firstConversation = service.startConversation() + val secondConversation = service.startConversation() + ConversationsState.getInstance().setCurrentConversation(firstConversation) + + val nextConversation = service.nextConversation + + assertThat(nextConversation.isPresent).isTrue() + assertThat(nextConversation.get()).isEqualTo(secondConversation) + } + + fun testDeleteSelectedConversation() { + val service = ConversationService.getInstance() + val firstConversation = service.startConversation() + service.startConversation() + + service.deleteSelectedConversation() + + assertThat(ConversationsState.getCurrentConversation()).isEqualTo(firstConversation) + assertThat(service.sortedConversations.size).isEqualTo(1) + assertThat(service.sortedConversations) + .extracting("id") + .containsExactly(firstConversation.id) + } + + fun testClearAllConversations() { + val service = ConversationService.getInstance() + service.startConversation() + service.startConversation() + + service.clearAll() + + assertThat(ConversationsState.getCurrentConversation()).isNull() + assertThat(service.sortedConversations.size).isEqualTo(0) + } +} diff --git a/src/test/kotlin/ee/carlrobert/codegpt/settings/state/GeneralSettingsTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/settings/state/GeneralSettingsTest.kt new file mode 100644 index 00000000..748f466e --- /dev/null +++ b/src/test/kotlin/ee/carlrobert/codegpt/settings/state/GeneralSettingsTest.kt @@ -0,0 +1,79 @@ +package ee.carlrobert.codegpt.settings.state + +import com.intellij.testFramework.fixtures.BasePlatformTestCase +import ee.carlrobert.codegpt.completions.HuggingFaceModel +import ee.carlrobert.codegpt.conversations.Conversation +import ee.carlrobert.codegpt.settings.GeneralSettings +import ee.carlrobert.codegpt.settings.service.ServiceType +import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings +import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings +import org.assertj.core.api.Assertions.assertThat + +class GeneralSettingsTest : BasePlatformTestCase() { + + fun testOpenAISettingsSync() { + val openAISettings = OpenAISettings.getCurrentState() + openAISettings.model = "gpt-3.5-turbo" + val conversation = Conversation() + conversation.model = "gpt-4" + conversation.clientCode = "chat.completion" + val settings = GeneralSettings.getInstance() + + settings.sync(conversation) + + assertThat(settings.state.selectedService).isEqualTo(ServiceType.OPENAI) + assertThat(openAISettings.model).isEqualTo("gpt-4") + } + + fun testAzureSettingsSync() { + val settings = GeneralSettings.getInstance() + val conversation = Conversation() + conversation.model = "gpt-4" + conversation.clientCode = "azure.chat.completion" + + settings.sync(conversation) + + assertThat(settings.state.selectedService).isEqualTo(ServiceType.AZURE) + } + + fun testYouSettingsSync() { + val settings = GeneralSettings.getInstance() + val conversation = Conversation() + conversation.model = "YouCode" + conversation.clientCode = "you.chat.completion" + + settings.sync(conversation) + + assertThat(settings.state.selectedService).isEqualTo(ServiceType.YOU) + } + + fun testLlamaSettingsModelPathSync() { + val llamaSettings = LlamaSettings.getCurrentState() + llamaSettings.huggingFaceModel = HuggingFaceModel.WIZARD_CODER_PYTHON_7B_Q3 + val conversation = Conversation() + conversation.model = "TEST_LLAMA_MODEL_PATH" + conversation.clientCode = "llama.chat.completion" + val settings = GeneralSettings.getInstance() + + settings.sync(conversation) + + assertThat(settings.state.selectedService).isEqualTo(ServiceType.LLAMA_CPP) + assertThat(llamaSettings.customLlamaModelPath).isEqualTo("TEST_LLAMA_MODEL_PATH") + assertThat(llamaSettings.isUseCustomModel).isTrue() + } + + fun testLlamaSettingsHuggingFaceModelSync() { + val llamaSettings = LlamaSettings.getCurrentState() + llamaSettings.huggingFaceModel = HuggingFaceModel.WIZARD_CODER_PYTHON_7B_Q3 + val conversation = Conversation() + conversation.model = "CODE_LLAMA_7B_Q3" + conversation.clientCode = "llama.chat.completion" + val settings = GeneralSettings.getInstance() + + settings.sync(conversation) + + assertThat(settings.state.selectedService).isEqualTo(ServiceType.LLAMA_CPP) + assertThat(llamaSettings.huggingFaceModel).isEqualTo(HuggingFaceModel.CODE_LLAMA_7B_Q3) + assertThat(llamaSettings.isUseCustomModel).isFalse() + } +} diff --git a/src/test/kotlin/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanelTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanelTest.kt new file mode 100644 index 00000000..c1a4aede --- /dev/null +++ b/src/test/kotlin/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabPanelTest.kt @@ -0,0 +1,410 @@ +package ee.carlrobert.codegpt.toolwindow.chat + +import ee.carlrobert.codegpt.CodeGPTKeys +import ee.carlrobert.codegpt.EncodingManager +import ee.carlrobert.codegpt.ReferencedFile +import ee.carlrobert.codegpt.completions.CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT +import ee.carlrobert.codegpt.completions.CompletionRequestProvider.FIX_COMPILE_ERRORS_SYSTEM_PROMPT +import ee.carlrobert.codegpt.completions.ConversationType +import ee.carlrobert.codegpt.completions.HuggingFaceModel +import ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA +import ee.carlrobert.codegpt.conversations.ConversationService +import ee.carlrobert.codegpt.conversations.message.Message +import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings +import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings +import ee.carlrobert.llm.client.http.RequestEntity +import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange +import ee.carlrobert.llm.client.util.JSONUtil.e +import ee.carlrobert.llm.client.util.JSONUtil.jsonArray +import ee.carlrobert.llm.client.util.JSONUtil.jsonMap +import ee.carlrobert.llm.client.util.JSONUtil.jsonMapResponse +import org.apache.http.HttpHeaders +import org.assertj.core.api.Assertions.assertThat +import testsupport.IntegrationTest +import java.io.IOException +import java.nio.file.Files +import java.nio.file.Path +import java.util.Base64 +import java.util.Objects + +class ChatToolWindowTabPanelTest : IntegrationTest() { + + fun testSendingOpenAIMessage() { + useOpenAIService() + ConfigurationSettings.getCurrentState().systemPrompt = COMPLETION_SYSTEM_PROMPT + val message = Message("Hello!") + val conversation = ConversationService.getInstance().startConversation() + val panel = ChatToolWindowTabPanel(project, conversation) + expectOpenAI(StreamHttpExchange { request: RequestEntity -> + assertThat(request.uri.path).isEqualTo("/v1/chat/completions") + assertThat(request.method).isEqualTo("POST") + assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY") + assertThat(request.body) + .extracting( + "model", + "messages") + .containsExactly( + "gpt-4", + listOf( + mapOf("role" to "system", "content" to COMPLETION_SYSTEM_PROMPT), + mapOf("role" to "user", "content" to "Hello!"))) + listOf( + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))) + }) + + panel.sendMessage(message) + + waitExpecting { + val messages = conversation.messages + messages.isNotEmpty() && "Hello!" == messages[0].response + } + val encodingManager = EncodingManager.getInstance() + assertThat(panel.tokenDetails).extracting( + "systemPromptTokens", + "conversationTokens", + "userPromptTokens", + "highlightedTokens") + .containsExactly( + encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT), + encodingManager.countTokens(message.prompt), + 0, + 0) + assertThat(panel.conversation) + .isNotNull() + .extracting("id", "model", "clientCode", "discardTokenLimit") + .containsExactly( + conversation.id, + conversation.model, + conversation.clientCode, + false) + val messages = panel.conversation.messages + assertThat(messages).hasSize(1) + assertThat(messages[0]) + .extracting("id", "prompt", "response") + .containsExactly(message.id, message.prompt, message.response) + } + + fun testSendingOpenAIMessageWithReferencedContext() { + project.putUserData(CodeGPTKeys.SELECTED_FILES, listOf( + ReferencedFile("TEST_FILE_NAME_1", "TEST_FILE_PATH_1", "TEST_FILE_CONTENT_1"), + ReferencedFile("TEST_FILE_NAME_2", "TEST_FILE_PATH_2", "TEST_FILE_CONTENT_2"), + ReferencedFile("TEST_FILE_NAME_3", "TEST_FILE_PATH_3", "TEST_FILE_CONTENT_3"))) + useOpenAIService() + ConfigurationSettings.getCurrentState().systemPrompt = COMPLETION_SYSTEM_PROMPT + val message = Message("TEST_MESSAGE") + message.userMessage = "TEST_MESSAGE" + message.referencedFilePaths = listOf("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3") + val conversation = ConversationService.getInstance().startConversation() + val panel = ChatToolWindowTabPanel(project, conversation) + expectOpenAI(StreamHttpExchange { request: RequestEntity -> + assertThat(request.uri.path).isEqualTo("/v1/chat/completions") + assertThat(request.method).isEqualTo("POST") + assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY") + assertThat(request.body) + .extracting( + "model", + "messages") + .containsExactly( + "gpt-4", + listOf( + mapOf("role" to "system", "content" to COMPLETION_SYSTEM_PROMPT), + mapOf("role" to "user", "content" to """ + Use the following context to answer question at the end: + + File Path: TEST_FILE_PATH_1 + File Content: + ```TEST_FILE_NAME_1 + TEST_FILE_CONTENT_1 + ``` + + File Path: TEST_FILE_PATH_2 + File Content: + ```TEST_FILE_NAME_2 + TEST_FILE_CONTENT_2 + ``` + + File Path: TEST_FILE_PATH_3 + File Content: + ```TEST_FILE_NAME_3 + TEST_FILE_CONTENT_3 + ``` + + Question: TEST_MESSAGE""".trimIndent()))) + listOf( + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))) + }) + + panel.sendMessage(message) + + waitExpecting { + val messages = conversation.messages + messages.isNotEmpty() && "Hello!" == messages[0].response + } + val encodingManager = EncodingManager.getInstance() + assertThat(panel.tokenDetails).extracting( + "systemPromptTokens", + "conversationTokens", + "userPromptTokens", + "highlightedTokens") + .containsExactly( + encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT), + encodingManager.countTokens(message.prompt), + 0, + 0) + assertThat(panel.conversation) + .isNotNull() + .extracting("id", "model", "clientCode", "discardTokenLimit") + .containsExactly( + conversation.id, + conversation.model, + conversation.clientCode, + false) + val messages = panel.conversation.messages + assertThat(messages).hasSize(1) + assertThat(messages[0]) + .extracting("id", "prompt", "response", "referencedFilePaths") + .containsExactly( + message.id, + message.prompt, + message.response, + listOf("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3")) + } + + fun testSendingOpenAIMessageWithImage() { + val testImagePath = Objects.requireNonNull(javaClass.getResource("/images/test-image.png")).path + project.putUserData(CodeGPTKeys.IMAGE_ATTACHMENT_FILE_PATH, testImagePath) + useOpenAIService("gpt-4-vision-preview") + ConfigurationSettings.getCurrentState().systemPrompt = COMPLETION_SYSTEM_PROMPT + val message = Message("TEST_MESSAGE") + val conversation = ConversationService.getInstance().startConversation() + val panel = ChatToolWindowTabPanel(project, conversation) + expectOpenAI(StreamHttpExchange { request: RequestEntity -> + assertThat(request.uri.path).isEqualTo("/v1/chat/completions") + assertThat(request.method).isEqualTo("POST") + assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY") + try { + val testImageUrl = ("data:image/png;base64," + + Base64.getEncoder().encodeToString(Files.readAllBytes(Path.of(testImagePath)))) + assertThat(request.body) + .extracting("model", "messages") + .containsExactly( + "gpt-4-vision-preview", + listOf( + mapOf("role" to "system", "content" to COMPLETION_SYSTEM_PROMPT), + mapOf("role" to "user", "content" to listOf( + mapOf( + "type" to "image_url", + "image_url" to mapOf("url" to testImageUrl)), + mapOf("type" to "text", "text" to "TEST_MESSAGE") + )))) + } catch (e: IOException) { + throw RuntimeException(e) + } + listOf( + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))) + }) + + panel.sendMessage(message) + + waitExpecting { + val messages = conversation.messages + messages.isNotEmpty() && "Hello!" == messages[0].response + } + val encodingManager = EncodingManager.getInstance() + assertThat(panel.tokenDetails).extracting( + "systemPromptTokens", + "conversationTokens", + "userPromptTokens", + "highlightedTokens") + .containsExactly( + encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT), + encodingManager.countTokens(message.prompt), + 0, + 0) + assertThat(panel.conversation) + .isNotNull() + .extracting("id", "model", "clientCode", "discardTokenLimit") + .containsExactly( + conversation.id, + conversation.model, + conversation.clientCode, + false) + val messages = panel.conversation.messages + assertThat(messages).hasSize(1) + assertThat(messages[0]) + .extracting("id", "prompt", "response", "imageFilePath") + .containsExactly( + message.id, + message.prompt, + message.response, + message.imageFilePath) + } + + fun testFixCompileErrorsWithOpenAIService() { + project.putUserData( + CodeGPTKeys.SELECTED_FILES, listOf( + ReferencedFile("TEST_FILE_NAME_1", "TEST_FILE_PATH_1", "TEST_FILE_CONTENT_1"), + ReferencedFile("TEST_FILE_NAME_2", "TEST_FILE_PATH_2", "TEST_FILE_CONTENT_2"), + ReferencedFile("TEST_FILE_NAME_3", "TEST_FILE_PATH_3", "TEST_FILE_CONTENT_3"))) + useOpenAIService() + ConfigurationSettings.getCurrentState().systemPrompt = COMPLETION_SYSTEM_PROMPT + val message = Message("TEST_MESSAGE") + message.userMessage = "TEST_MESSAGE" + message.referencedFilePaths = listOf("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3") + val conversation = ConversationService.getInstance().startConversation() + val panel = ChatToolWindowTabPanel(project, conversation) + expectOpenAI(StreamHttpExchange { request: RequestEntity -> + assertThat(request.uri.path).isEqualTo("/v1/chat/completions") + assertThat(request.method).isEqualTo("POST") + assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY") + assertThat(request.body) + .extracting( + "model", + "messages") + .containsExactly( + "gpt-4", + listOf( + mapOf("role" to "system", "content" to FIX_COMPILE_ERRORS_SYSTEM_PROMPT), + mapOf("role" to "user", "content" to """ + Use the following context to answer question at the end: + + File Path: TEST_FILE_PATH_1 + File Content: + ```TEST_FILE_NAME_1 + TEST_FILE_CONTENT_1 + ``` + + File Path: TEST_FILE_PATH_2 + File Content: + ```TEST_FILE_NAME_2 + TEST_FILE_CONTENT_2 + ``` + + File Path: TEST_FILE_PATH_3 + File Content: + ```TEST_FILE_NAME_3 + TEST_FILE_CONTENT_3 + ``` + + Question: TEST_MESSAGE""".trimIndent()))) + listOf( + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))), + jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!"))))) + }) + + panel.sendMessage(message, ConversationType.FIX_COMPILE_ERRORS) + + waitExpecting { + val messages = conversation.messages + messages.isNotEmpty() && "Hello!" == messages[0].response + } + val encodingManager = EncodingManager.getInstance() + assertThat(panel.tokenDetails).extracting( + "systemPromptTokens", + "conversationTokens", + "userPromptTokens", + "highlightedTokens") + .containsExactly( + encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT), + encodingManager.countTokens(message.prompt), + 0, + 0) + assertThat(panel.conversation) + .isNotNull() + .extracting("id", "model", "clientCode", "discardTokenLimit") + .containsExactly( + conversation.id, + conversation.model, + conversation.clientCode, + false) + val messages = panel.conversation.messages + assertThat(messages).hasSize(1) + assertThat(messages[0]) + .extracting("id", "prompt", "response", "referencedFilePaths") + .containsExactly( + message.id, + message.prompt, + message.response, + listOf("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3")) + } + + fun testSendingLlamaMessage() { + useLlamaService() + val configurationState = ConfigurationSettings.getCurrentState() + configurationState.systemPrompt = COMPLETION_SYSTEM_PROMPT + configurationState.maxTokens = 1000 + configurationState.temperature = 0.1 + val llamaSettings = LlamaSettings.getCurrentState() + llamaSettings.isUseCustomModel = false + llamaSettings.huggingFaceModel = HuggingFaceModel.CODE_LLAMA_7B_Q4 + llamaSettings.topK = 30 + llamaSettings.topP = 0.8 + llamaSettings.minP = 0.03 + llamaSettings.repeatPenalty = 1.3 + val message = Message("TEST_PROMPT") + val conversation = ConversationService.getInstance().startConversation() + val panel = ChatToolWindowTabPanel(project, conversation) + expectLlama(StreamHttpExchange { request: RequestEntity -> + assertThat(request.uri.path).isEqualTo("/completion") + assertThat(request.body) + .extracting( + "prompt", + "n_predict", + "stream", + "temperature", + "top_k", + "top_p", + "min_p", + "repeat_penalty") + .containsExactly( + LLAMA.buildPrompt( + COMPLETION_SYSTEM_PROMPT, + "TEST_PROMPT", + conversation.messages), + configurationState.maxTokens, + true, + configurationState.temperature, + llamaSettings.topK, + llamaSettings.topP, + llamaSettings.minP, + llamaSettings.repeatPenalty) + listOf( + jsonMapResponse("content", "Hel"), + jsonMapResponse("content", "lo!"), + jsonMapResponse( + e("content", ""), + e("stop", true))) + }) + + panel.sendMessage(message, ConversationType.DEFAULT) + + waitExpecting { + val messages = conversation.messages + messages.isNotEmpty() && "Hello!" == messages[0].response + } + assertThat(panel.conversation) + .isNotNull() + .extracting("id", "model", "clientCode", "discardTokenLimit") + .containsExactly( + conversation.id, + conversation.model, + conversation.clientCode, + false) + val messages = panel.conversation.messages + assertThat(messages).hasSize(1) + assertThat(messages[0]) + .extracting("id", "prompt", "response") + .containsExactly(message.id, message.prompt, message.response) + } +} diff --git a/src/test/kotlin/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabbedPaneTest.kt b/src/test/kotlin/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabbedPaneTest.kt new file mode 100644 index 00000000..adbaaf97 --- /dev/null +++ b/src/test/kotlin/ee/carlrobert/codegpt/toolwindow/chat/ChatToolWindowTabbedPaneTest.kt @@ -0,0 +1,50 @@ +package ee.carlrobert.codegpt.toolwindow.chat + +import com.intellij.openapi.util.Disposer +import com.intellij.testFramework.fixtures.BasePlatformTestCase +import ee.carlrobert.codegpt.conversations.ConversationService +import ee.carlrobert.codegpt.conversations.message.Message +import org.assertj.core.api.Assertions.assertThat + +class ChatToolWindowTabbedPaneTest : BasePlatformTestCase() { + + fun testClearAllTabs() { + val tabbedPane = ChatToolWindowTabbedPane(Disposer.newDisposable()) + tabbedPane.addNewTab(createNewTabPanel()) + + tabbedPane.clearAll() + + assertThat(tabbedPane.activeTabMapping).isEmpty() + } + + + fun testAddingNewTabs() { + val tabbedPane = ChatToolWindowTabbedPane(Disposer.newDisposable()) + + tabbedPane.addNewTab(createNewTabPanel()) + tabbedPane.addNewTab(createNewTabPanel()) + tabbedPane.addNewTab(createNewTabPanel()) + + assertThat(tabbedPane.activeTabMapping.keys) + .containsExactly("Chat 1", "Chat 2", "Chat 3") + } + + fun testResetCurrentlyActiveTabPanel() { + val tabbedPane = ChatToolWindowTabbedPane(Disposer.newDisposable()) + val conversation = ConversationService.getInstance().startConversation() + conversation.addMessage(Message("TEST_PROMPT", "TEST_RESPONSE")) + tabbedPane.addNewTab(ChatToolWindowTabPanel(project, conversation)) + + tabbedPane.resetCurrentlyActiveTabPanel(project) + + val tabPanel = tabbedPane.activeTabMapping["Chat 1"] + assertThat(tabPanel!!.conversation.messages).isEmpty() + } + + private fun createNewTabPanel(): ChatToolWindowTabPanel { + return ChatToolWindowTabPanel( + project, + ConversationService.getInstance().startConversation() + ) + } +} diff --git a/src/test/java/ee/carlrobert/codegpt/util/MarkdownUtilTest.java b/src/test/kotlin/ee/carlrobert/codegpt/util/MarkdownUtilTest.kt similarity index 73% rename from src/test/java/ee/carlrobert/codegpt/util/MarkdownUtilTest.java rename to src/test/kotlin/ee/carlrobert/codegpt/util/MarkdownUtilTest.kt index 32917b01..f1b60c95 100644 --- a/src/test/java/ee/carlrobert/codegpt/util/MarkdownUtilTest.java +++ b/src/test/kotlin/ee/carlrobert/codegpt/util/MarkdownUtilTest.kt @@ -1,14 +1,13 @@ -package ee.carlrobert.codegpt.util; +package ee.carlrobert.codegpt.util -import static org.assertj.core.api.Assertions.assertThat; +import org.assertj.core.api.Assertions.assertThat +import org.junit.Test -import org.junit.Test; - -public class MarkdownUtilTest { +class MarkdownUtilTest { @Test - public void shouldExtractMarkdownCodeBlocks() { - String testInput = """ + fun shouldExtractMarkdownCodeBlocks() { + val testInput = """ **C++ Code Block** ```cpp #include @@ -29,60 +28,68 @@ public class MarkdownUtilTest { ``` 1. We define a **public class** called **Main**. 2. We define the **main** method which is the entry point of the program. - """; - var result = MarkdownUtil.splitCodeBlocks(testInput); + """.trimIndent() + + val result = MarkdownUtil.splitCodeBlocks(testInput) assertThat(result).containsExactly(""" **C++ Code Block** - """, """ + + """.trimIndent(), """ ```cpp #include int main() { return 0; } - ```""", """ + ```""".trimIndent(), """ 1. We include the **iostream** header file. 2. We define the main function. **Java Code Block** - """, """ + + """.trimIndent(), """ ```java public class Main { public static void main(String[] args) { } } - ```""", """ + ```""".trimIndent(), """ 1. We define a **public class** called **Main**. 2. We define the **main** method which is the entry point of the program. - """); + + """.trimIndent()) } @Test - public void shouldExtractMarkdownWithoutCode() { - String testInput = """ + fun shouldExtractMarkdownWithoutCode() { + val testInput = """ **C++ Code Block** 1. We include the **iostream** header file. 2. We define the main function. - """; - var result = MarkdownUtil.splitCodeBlocks(testInput); + + """.trimIndent() + + val result = MarkdownUtil.splitCodeBlocks(testInput) assertThat(result).containsExactly(""" **C++ Code Block** 1. We include the **iostream** header file. 2. We define the main function. - """); + + + """.trimIndent()) } @Test - public void shouldExtractMarkdownCodeOnly() { - String testInput = """ + fun shouldExtractMarkdownCodeOnly() { + val testInput = """ ```cpp #include @@ -96,9 +103,10 @@ public class MarkdownUtilTest { } } ``` - """; - var result = MarkdownUtil.splitCodeBlocks(testInput); + """.trimIndent() + + val result = MarkdownUtil.splitCodeBlocks(testInput) assertThat(result).containsExactly(""" ```cpp @@ -107,12 +115,12 @@ public class MarkdownUtilTest { int main() { return 0; } - ```""", """ + ```""".trimIndent(), """ ```java public class Main { public static void main(String[] args) { } } - ```"""); + ```""".trimIndent()) } } diff --git a/src/test/kotlin/testsupport/IntegrationTest.kt b/src/test/kotlin/testsupport/IntegrationTest.kt new file mode 100644 index 00000000..28da2eca --- /dev/null +++ b/src/test/kotlin/testsupport/IntegrationTest.kt @@ -0,0 +1,33 @@ +package testsupport + +import com.intellij.openapi.util.Key +import com.intellij.testFramework.fixtures.BasePlatformTestCase +import ee.carlrobert.codegpt.CodeGPTKeys +import ee.carlrobert.llm.client.mixin.ExternalServiceTestMixin +import testsupport.mixin.ShortcutsTestMixin + +open class IntegrationTest : BasePlatformTestCase(), ExternalServiceTestMixin, ShortcutsTestMixin { + + @Throws(Exception::class) + override fun tearDown() { + ExternalServiceTestMixin.clearAll() + clearKeys() + super.tearDown() + } + + private fun clearKeys() { + putUserData(CodeGPTKeys.SELECTED_FILES, emptyList()) + putUserData(CodeGPTKeys.PREVIOUS_INLAY_TEXT, "") + putUserData(CodeGPTKeys.IMAGE_ATTACHMENT_FILE_PATH, "") + } + + private fun putUserData(key: Key, value: T) { + project.putUserData(key, value) + } + + companion object { + init { + ExternalServiceTestMixin.init() + } + } +} diff --git a/src/test/kotlin/testsupport/mixin/ShortcutsTestMixin.kt b/src/test/kotlin/testsupport/mixin/ShortcutsTestMixin.kt new file mode 100644 index 00000000..284a7ba9 --- /dev/null +++ b/src/test/kotlin/testsupport/mixin/ShortcutsTestMixin.kt @@ -0,0 +1,52 @@ +package testsupport.mixin + +import com.intellij.testFramework.PlatformTestUtil +import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.AZURE_OPENAI_API_KEY +import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.OPENAI_API_KEY +import ee.carlrobert.codegpt.credentials.CredentialsStore.setCredential +import ee.carlrobert.codegpt.settings.GeneralSettings +import ee.carlrobert.codegpt.settings.service.ServiceType +import ee.carlrobert.codegpt.settings.service.azure.AzureSettings +import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings +import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings +import java.util.function.BooleanSupplier + +interface ShortcutsTestMixin { + + fun useOpenAIService() { + useOpenAIService("gpt-4") + } + + fun useOpenAIService(model: String? = "gpt-4") { + GeneralSettings.getCurrentState().selectedService = ServiceType.OPENAI + setCredential(OPENAI_API_KEY, "TEST_API_KEY") + OpenAISettings.getCurrentState().model = model + } + + fun useAzureService() { + GeneralSettings.getCurrentState().selectedService = ServiceType.AZURE + setCredential(AZURE_OPENAI_API_KEY, "TEST_API_KEY") + val azureSettings = AzureSettings.getCurrentState() + azureSettings.resourceName = "TEST_RESOURCE_NAME" + azureSettings.apiVersion = "TEST_API_VERSION" + azureSettings.deploymentId = "TEST_DEPLOYMENT_ID" + } + + fun useYouService() { + GeneralSettings.getCurrentState().selectedService = ServiceType.YOU + } + + fun useLlamaService() { + GeneralSettings.getCurrentState().selectedService = ServiceType.LLAMA_CPP + LlamaSettings.getCurrentState().serverPort = null + } + + fun waitExpecting(condition: BooleanSupplier?) { + PlatformTestUtil.waitWithEventsDispatching( + "Waiting for message response timed out or did not meet expected conditions", + condition!!, + 5 + ) + } + +}