chore: Convert Java tests to Kotlin (#447)

This commit is contained in:
Rene Leonhardt 2024-04-11 11:03:31 +02:00 committed by GitHub
parent 6fb0b8d30c
commit 0cdd5096ba
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 1276 additions and 1271 deletions

View file

@@ -1,43 +0,0 @@
package ee.carlrobert.codegpt.codecompletions;
import static ee.carlrobert.codegpt.CodeGPTKeys.PREVIOUS_INLAY_TEXT;
import static ee.carlrobert.codegpt.codecompletions.InfillPromptTemplate.LLAMA;
import static ee.carlrobert.codegpt.util.file.FileUtil.getResourceContent;
import static ee.carlrobert.llm.client.util.JSONUtil.e;
import static ee.carlrobert.llm.client.util.JSONUtil.jsonMapResponse;
import static org.assertj.core.api.Assertions.assertThat;
import com.intellij.openapi.editor.VisualPosition;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange;
import java.util.List;
import testsupport.IntegrationTest;
/**
 * Integration test for the llama.cpp-backed code-completion flow: typing in the
 * editor must issue an infill request with the expected prompt and surface the
 * streamed completion as inlay text.
 */
public class CodeCompletionServiceTest extends IntegrationTest {

  private final VisualPosition caretPosition = new VisualPosition(3, 0);

  public void testFetchCodeCompletionLlama() {
    useLlamaService();
    LlamaSettings.getCurrentState().setCodeCompletionsEnabled(true);
    myFixture.configureByText(
        "CompletionTest.java",
        getResourceContent("/codecompletions/code-completion-file.txt"));
    myFixture.getEditor().getCaretModel().moveToVisualPosition(caretPosition);
    var completionResponse = "TEST_OUTPUT";
    // Padding around the caret keeps prefix and suffix at exactly 128 tokens each.
    var infillPrefix = "z".repeat(245) + "\n[INPUT]\nc";
    var infillSuffix = "\n[\\INPUT]\n" + "z".repeat(247);
    expectLlama((StreamHttpExchange) request -> {
      assertThat(request.getUri().getPath()).isEqualTo("/completion");
      assertThat(request.getMethod()).isEqualTo("POST");
      assertThat(request.getBody())
          .extracting("prompt")
          .isEqualTo(LLAMA.buildPrompt(infillPrefix, infillSuffix));
      return List.of(jsonMapResponse(e("content", completionResponse), e("stop", true)));
    });

    myFixture.type('c');

    waitExpecting(() -> "TEST_OUTPUT".equals(PREVIOUS_INLAY_TEXT.get(myFixture.getEditor())));
  }
}

View file

@@ -1,154 +0,0 @@
package ee.carlrobert.codegpt.completions;
import static ee.carlrobert.codegpt.completions.CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.groups.Tuple.tuple;
import ee.carlrobert.codegpt.conversations.ConversationService;
import ee.carlrobert.codegpt.conversations.message.Message;
import ee.carlrobert.codegpt.credentials.CredentialsStore;
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey;
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings;
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel;
import testsupport.IntegrationTest;
/**
 * Tests for {@code CompletionRequestProvider}: verifies how an OpenAI chat
 * completion request is assembled from a conversation — system-prompt override
 * vs. default, retry truncation, token-limit reduction, and overflow handling.
 */
public class CompletionRequestProviderTest extends IntegrationTest {

  // A user-configured system prompt must replace the default COMPLETION_SYSTEM_PROMPT.
  public void testChatCompletionRequestWithSystemPromptOverride() {
    CredentialsStore.INSTANCE.setCredential(CredentialKey.OPENAI_API_KEY, "TEST_API_KEY");
    ConfigurationSettings.getCurrentState().setSystemPrompt("TEST_SYSTEM_PROMPT");
    var conversation = ConversationService.getInstance().startConversation();
    var firstMessage = createDummyMessage(500);
    var secondMessage = createDummyMessage(250);
    conversation.addMessage(firstMessage);
    conversation.addMessage(secondMessage);
    var request = new CompletionRequestProvider(conversation)
        .buildOpenAIChatCompletionRequest(
            OpenAIChatCompletionModel.GPT_3_5.getCode(),
            new CallParameters(
                conversation,
                ConversationType.DEFAULT,
                new Message("TEST_CHAT_COMPLETION_PROMPT"),
                false));
    // Expected layout: system prompt, then the full history, then the new prompt.
    assertThat(request.getMessages())
        .extracting("role", "content")
        .containsExactly(
            tuple("system", "TEST_SYSTEM_PROMPT"),
            tuple("user", "TEST_PROMPT"),
            tuple("assistant", firstMessage.getResponse()),
            tuple("user", "TEST_PROMPT"),
            tuple("assistant", secondMessage.getResponse()),
            tuple("user", "TEST_CHAT_COMPLETION_PROMPT"));
  }

  // Without an override the default COMPLETION_SYSTEM_PROMPT is used.
  public void testChatCompletionRequestWithoutSystemPromptOverride() {
    var conversation = ConversationService.getInstance().startConversation();
    var firstMessage = createDummyMessage(500);
    var secondMessage = createDummyMessage(250);
    conversation.addMessage(firstMessage);
    conversation.addMessage(secondMessage);
    var request = new CompletionRequestProvider(conversation)
        .buildOpenAIChatCompletionRequest(
            OpenAIChatCompletionModel.GPT_3_5.getCode(),
            new CallParameters(
                conversation,
                ConversationType.DEFAULT,
                new Message("TEST_CHAT_COMPLETION_PROMPT"),
                false));
    assertThat(request.getMessages())
        .extracting("role", "content")
        .containsExactly(
            tuple("system", COMPLETION_SYSTEM_PROMPT),
            tuple("user", "TEST_PROMPT"),
            tuple("assistant", firstMessage.getResponse()),
            tuple("user", "TEST_PROMPT"),
            tuple("assistant", secondMessage.getResponse()),
            tuple("user", "SECOND_TEST_PROMPT") /* see note */);
  }

  // On retry (last CallParameters arg = true) the failed message's previous
  // response is dropped: the request ends with the retried user prompt.
  public void testChatCompletionRequestRetry() {
    ConfigurationSettings.getCurrentState().setSystemPrompt(COMPLETION_SYSTEM_PROMPT);
    var conversation = ConversationService.getInstance().startConversation();
    var firstMessage = createDummyMessage("FIRST_TEST_PROMPT", 500);
    var secondMessage = createDummyMessage("SECOND_TEST_PROMPT", 250);
    conversation.addMessage(firstMessage);
    conversation.addMessage(secondMessage);
    var request = new CompletionRequestProvider(conversation)
        .buildOpenAIChatCompletionRequest(
            OpenAIChatCompletionModel.GPT_3_5.getCode(),
            new CallParameters(
                conversation,
                ConversationType.DEFAULT,
                secondMessage,
                true));
    assertThat(request.getMessages())
        .extracting("role", "content")
        .containsExactly(
            tuple("system", COMPLETION_SYSTEM_PROMPT),
            tuple("user", "FIRST_TEST_PROMPT"),
            tuple("assistant", firstMessage.getResponse()),
            tuple("user", "SECOND_TEST_PROMPT"));
  }

  // When the history exceeds the model's context window, older messages are
  // dropped until the request fits; only the most recent exchange survives.
  public void testReducedChatCompletionRequest() {
    var conversation = ConversationService.getInstance().startConversation();
    conversation.addMessage(createDummyMessage(50));
    conversation.addMessage(createDummyMessage(100));
    conversation.addMessage(createDummyMessage(150));
    conversation.addMessage(createDummyMessage(1000));
    var remainingMessage = createDummyMessage(2000);
    conversation.addMessage(remainingMessage);
    conversation.discardTokenLimits();
    var request = new CompletionRequestProvider(conversation)
        .buildOpenAIChatCompletionRequest(
            OpenAIChatCompletionModel.GPT_3_5.getCode(),
            new CallParameters(
                conversation,
                ConversationType.DEFAULT,
                new Message("TEST_CHAT_COMPLETION_PROMPT"),
                false));
    assertThat(request.getMessages())
        .extracting("role", "content")
        .containsExactly(
            tuple("system", COMPLETION_SYSTEM_PROMPT),
            tuple("user", "TEST_PROMPT"),
            tuple("assistant", remainingMessage.getResponse()),
            tuple("user", "TEST_CHAT_COMPLETION_PROMPT"));
  }

  // If even the reduced request cannot fit, building must fail loudly.
  public void testTotalUsageExceededException() {
    var conversation = ConversationService.getInstance().startConversation();
    conversation.addMessage(createDummyMessage(1500));
    conversation.addMessage(createDummyMessage(1500));
    conversation.addMessage(createDummyMessage(1500));
    assertThrows(TotalUsageExceededException.class,
        () -> new CompletionRequestProvider(conversation)
            .buildOpenAIChatCompletionRequest(
                OpenAIChatCompletionModel.GPT_3_5.getCode(),
                new CallParameters(
                    conversation,
                    ConversationType.DEFAULT,
                    createDummyMessage(100),
                    false)));
  }

  // Convenience overload: dummy message with the default "TEST_PROMPT" prompt.
  private Message createDummyMessage(int tokenSize) {
    return createDummyMessage("TEST_PROMPT", tokenSize);
  }

  /**
   * Builds a message whose prompt+response together cost roughly
   * {@code tokenSize} tokens: each "zz" repeat counts as 1 token, the prompt
   * costs 6 tokens, and the per-message framing overhead is 7 tokens (GPT-3).
   */
  private Message createDummyMessage(String prompt, int tokenSize) {
    var message = new Message(prompt);
    // 'zz' = 1 token, prompt = 6 tokens, 7 tokens per message (GPT-3),
    message.setResponse("zz".repeat((tokenSize) - 6 - 7));
    return message;
  }
}

View file

@@ -1,183 +0,0 @@
package ee.carlrobert.codegpt.completions;
import static ee.carlrobert.codegpt.completions.CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT;
import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA;
import static ee.carlrobert.llm.client.util.JSONUtil.e;
import static ee.carlrobert.llm.client.util.JSONUtil.jsonArray;
import static ee.carlrobert.llm.client.util.JSONUtil.jsonMap;
import static ee.carlrobert.llm.client.util.JSONUtil.jsonMapResponse;
import static org.apache.http.HttpHeaders.AUTHORIZATION;
import static org.assertj.core.api.Assertions.assertThat;
import ee.carlrobert.codegpt.CodeGPTPlugin;
import ee.carlrobert.codegpt.conversations.ConversationService;
import ee.carlrobert.codegpt.conversations.message.Message;
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings;
import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange;
import java.util.List;
import java.util.Map;
import testsupport.IntegrationTest;
/**
 * Integration tests for {@code CompletionRequestHandler}: each test points the
 * handler at a mocked provider (OpenAI, Azure, You.com, llama.cpp), asserts the
 * exact outgoing HTTP request, and streams back chunks that must be reassembled
 * into the response "Hello!".
 *
 * <p>Fix: {@link #testYouChatCompletionCall()} asserted the request path
 * {@code /api/streamingSearch} twice; the redundant duplicate is removed.
 */
public class DefaultCompletionRequestHandlerTest extends IntegrationTest {

  public void testOpenAIChatCompletionCall() {
    useOpenAIService();
    var message = new Message("TEST_PROMPT");
    var conversation = ConversationService.getInstance().startConversation();
    var requestHandler = new CompletionRequestHandler(getRequestEventListener(message));
    expectOpenAI((StreamHttpExchange) request -> {
      assertThat(request.getUri().getPath()).isEqualTo("/v1/chat/completions");
      assertThat(request.getMethod()).isEqualTo("POST");
      assertThat(request.getHeaders().get(AUTHORIZATION).get(0)).isEqualTo("Bearer TEST_API_KEY");
      assertThat(request.getBody())
          .extracting(
              "model",
              "messages")
          .containsExactly(
              "gpt-4",
              List.of(
                  Map.of("role", "system", "content", COMPLETION_SYSTEM_PROMPT),
                  Map.of("role", "user", "content", "TEST_PROMPT")));
      // SSE chunks: role delta first, then content deltas concatenating to "Hello!".
      return List.of(
          jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
          jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
          jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
          jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))));
    });

    requestHandler.call(new CallParameters(conversation, ConversationType.DEFAULT, message, false));

    waitExpecting(() -> "Hello!".equals(message.getResponse()));
  }

  public void testAzureChatCompletionCall() {
    useAzureService();
    var conversationService = ConversationService.getInstance();
    var prevMessage = new Message("TEST_PREV_PROMPT");
    prevMessage.setResponse("TEST_PREV_RESPONSE");
    var conversation = conversationService.startConversation();
    conversation.addMessage(prevMessage);
    conversationService.saveConversation(conversation);
    expectAzure((StreamHttpExchange) request -> {
      assertThat(request.getUri().getPath()).isEqualTo(
          "/openai/deployments/TEST_DEPLOYMENT_ID/chat/completions");
      assertThat(request.getUri().getQuery()).isEqualTo("api-version=TEST_API_VERSION");
      assertThat(request.getHeaders().get("Api-key").get(0)).isEqualTo("TEST_API_KEY");
      assertThat(request.getHeaders().get("X-llm-application-tag").get(0)).isEqualTo("codegpt");
      // Prior exchange from the saved conversation must precede the new prompt.
      assertThat(request.getBody())
          .extracting("messages")
          .isEqualTo(
              List.of(
                  Map.of("role", "system", "content", COMPLETION_SYSTEM_PROMPT),
                  Map.of("role", "user", "content", "TEST_PREV_PROMPT"),
                  Map.of("role", "assistant", "content", "TEST_PREV_RESPONSE"),
                  Map.of("role", "user", "content", "TEST_PROMPT")));
      return List.of(
          jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
          jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
          jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
          jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))));
    });
    var message = new Message("TEST_PROMPT");
    var requestHandler = new CompletionRequestHandler(getRequestEventListener(message));

    requestHandler.call(new CallParameters(conversation, ConversationType.DEFAULT, message, false));

    waitExpecting(() -> "Hello!".equals(message.getResponse()));
  }

  public void testYouChatCompletionCall() {
    useYouService();
    var message = new Message("TEST_PROMPT");
    var conversation = ConversationService.getInstance().startConversation();
    conversation.addMessage(new Message("Ping", "Pong"));
    var requestHandler = new CompletionRequestHandler(getRequestEventListener(message));
    expectYou((StreamHttpExchange) request -> {
      assertThat(request.getUri().getPath()).isEqualTo("/api/streamingSearch");
      assertThat(request.getMethod()).isEqualTo("GET");
      assertThat(request.getUri().getQuery()).isEqualTo(
          "q=TEST_PROMPT&"
              + "page=1&"
              + "cfr=CodeGPT&"
              + "count=10&"
              + "safeSearch=WebPages,Translations,TimeZone,Computation,RelatedSearches&"
              + "domain=youchat&"
              + "selectedChatMode=default&"
              + "chat=[{\"question\":\"Ping\",\"answer\":\"Pong\"}]&"
              + "utm_source=ide&"
              + "utm_medium=jetbrains&"
              + "utm_campaign=" + CodeGPTPlugin.getVersion() + "&"
              + "utm_content=CodeGPT");
      assertThat(request.getHeaders())
          .flatExtracting("Accept", "Connection", "User-agent", "Cookie")
          .containsExactly(
              "text/event-stream",
              "Keep-Alive",
              "youide CodeGPT",
              "safesearch_guest=Moderate; "
                  + "youpro_subscription=true; "
                  + "you_subscription=free; "
                  + "stytch_session=; "
                  + "ydc_stytch_session=; "
                  + "stytch_session_jwt=; "
                  + "ydc_stytch_session_jwt=; "
                  + "eg4=false; "
                  + "__cf_bm=aN2b3pQMH8XADeMB7bg9s1bJ_bfXBcCHophfOGRg6g0-1693601599-0-"
                  + "AWIt5Mr4Y3xQI4mIJ1lSf4+vijWKDobrty8OopDeBxY+NABe0MRFidF3dCUoWjRt8"
                  + "SVMvBZPI3zkOgcRs7Mz3yazd7f7c58HwW5Xg9jdBjNg;");
      return List.of(
          jsonMapResponse("youChatToken", "Hel"),
          jsonMapResponse("youChatToken", "lo"),
          jsonMapResponse("youChatToken", "!"));
    });

    requestHandler.call(new CallParameters(conversation, ConversationType.DEFAULT, message, false));

    waitExpecting(() -> "Hello!".equals(message.getResponse()));
  }

  public void testLlamaChatCompletionCall() {
    useLlamaService();
    ConfigurationSettings.getCurrentState().setMaxTokens(99);
    var message = new Message("TEST_PROMPT");
    var conversation = ConversationService.getInstance().startConversation();
    conversation.addMessage(new Message("Ping", "Pong"));
    var requestHandler = new CompletionRequestHandler(getRequestEventListener(message));
    expectLlama((StreamHttpExchange) request -> {
      assertThat(request.getUri().getPath()).isEqualTo("/completion");
      assertThat(request.getBody())
          .extracting(
              "prompt",
              "n_predict",
              "stream")
          .containsExactly(
              LLAMA.buildPrompt(
                  COMPLETION_SYSTEM_PROMPT,
                  "TEST_PROMPT",
                  conversation.getMessages()),
              99,
              true);
      return List.of(
          jsonMapResponse("content", "Hel"),
          jsonMapResponse("content", "lo!"),
          jsonMapResponse(
              e("content", ""),
              e("stop", true)));
    });

    requestHandler.call(new CallParameters(conversation, ConversationType.DEFAULT, message, false));

    waitExpecting(() -> "Hello!".equals(message.getResponse()));
  }

  /**
   * Listener that copies the fully assembled streamed response back into the
   * given message so the tests can wait on {@code message.getResponse()}.
   */
  private CompletionResponseEventListener getRequestEventListener(Message message) {
    return new CompletionResponseEventListener() {
      @Override
      public void handleCompleted(String fullMessage, CallParameters callParameters) {
        message.setResponse(fullMessage);
      }
    };
  }
}

View file

@@ -1,145 +0,0 @@
package ee.carlrobert.codegpt.completions;
import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.ALPACA;
import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.CHAT_ML;
import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA;
import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.TORA;
import static org.assertj.core.api.Assertions.assertThat;
import ee.carlrobert.codegpt.conversations.message.Message;
import java.util.List;
import org.junit.Test;
/**
 * Unit tests pinning the exact prompt text produced by each {@code PromptTemplate}
 * (LLaMA, Alpaca, ChatML, ToRA), with and without conversation history.
 * The text blocks below are byte-exact expectations — whitespace matters.
 */
public class PromptTemplateTest {

  private static final String SYSTEM_PROMPT = "TEST_SYSTEM_PROMPT";
  private static final String USER_PROMPT = "TEST_USER_PROMPT";
  // Two prior exchanges used by all "...WithHistory" tests.
  private static final List<Message> HISTORY = List.of(
      new Message("TEST_PREV_PROMPT_1", "TEST_PREV_RESPONSE_1"),
      new Message("TEST_PREV_PROMPT_2", "TEST_PREV_RESPONSE_2"));

  @Test
  public void shouldBuildLlamaPromptWithHistory() {
    var prompt = LLAMA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY);
    assertThat(prompt).isEqualTo("""
        <<SYS>>TEST_SYSTEM_PROMPT<</SYS>>
        [INST]TEST_PREV_PROMPT_1[/INST]
        TEST_PREV_RESPONSE_1
        [INST]TEST_PREV_PROMPT_2[/INST]
        TEST_PREV_RESPONSE_2
        [INST]TEST_USER_PROMPT[/INST]""");
  }

  @Test
  public void shouldBuildLlamaPromptWithoutHistory() {
    var prompt = LLAMA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, List.of());
    assertThat(prompt).isEqualTo("""
        <<SYS>>TEST_SYSTEM_PROMPT<</SYS>>
        [INST]TEST_USER_PROMPT[/INST]""");
  }

  @Test
  public void shouldBuildAlpacaPromptWithHistory() {
    var prompt = ALPACA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY);
    // Note: trailing "\" joins the first two lines; the prompt ends with a newline.
    assertThat(prompt).isEqualTo("""
        Below is an instruction that describes a task. \
        Write a response that appropriately completes the request.
        ### Instruction
        TEST_PREV_PROMPT_1
        ### Response:
        TEST_PREV_RESPONSE_1
        ### Instruction
        TEST_PREV_PROMPT_2
        ### Response:
        TEST_PREV_RESPONSE_2
        ### Instruction
        TEST_USER_PROMPT
        ### Response:
        """);
  }

  @Test
  public void shouldBuildAlpacaPromptWithoutHistory() {
    var prompt = ALPACA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, List.of());
    assertThat(prompt).isEqualTo("""
        Below is an instruction that describes a task. \
        Write a response that appropriately completes the request.
        ### Instruction
        TEST_USER_PROMPT
        ### Response:
        """);
  }

  @Test
  public void shouldBuildChatMLPromptWithHistory() {
    var prompt = CHAT_ML.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY);
    assertThat(prompt).isEqualTo("""
        <|im_start|>system
        TEST_SYSTEM_PROMPT<|im_end|>
        <|im_start|>user
        TEST_PREV_PROMPT_1<|im_end|>
        <|im_start|>assistant
        TEST_PREV_RESPONSE_1<|im_end|>
        <|im_start|>user
        TEST_PREV_PROMPT_2<|im_end|>
        <|im_start|>assistant
        TEST_PREV_RESPONSE_2<|im_end|>
        <|im_start|>user
        TEST_USER_PROMPT<|im_end|>"""
    );
  }

  @Test
  public void shouldBuildChatMLPromptWithoutHistory() {
    var prompt = CHAT_ML.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, List.of());
    assertThat(prompt).isEqualTo("""
        <|im_start|>system
        TEST_SYSTEM_PROMPT<|im_end|>
        <|im_start|>user
        TEST_USER_PROMPT<|im_end|>""");
  }

  @Test
  public void shouldBuildToRAPromptWithHistory() {
    var prompt = TORA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY);
    // ToRA ignores the system prompt entirely; only user/assistant turns appear.
    assertThat(prompt).isEqualTo("""
        <|user|>
        TEST_PREV_PROMPT_1
        <|assistant|>
        TEST_PREV_RESPONSE_1
        <|user|>
        TEST_PREV_PROMPT_2
        <|assistant|>
        TEST_PREV_RESPONSE_2
        <|user|>
        TEST_USER_PROMPT
        <|assistant|>"""
    );
  }

  @Test
  public void shouldBuildToRAPromptWithoutHistory() {
    var prompt = TORA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, List.of());
    assertThat(prompt).isEqualTo("""
        <|user|>
        TEST_USER_PROMPT
        <|assistant|>"""
    );
  }
}

View file

@@ -1,90 +0,0 @@
package ee.carlrobert.codegpt.conversations;
import static org.assertj.core.api.Assertions.assertThat;
import com.intellij.testFramework.fixtures.BasePlatformTestCase;
import ee.carlrobert.codegpt.conversations.message.Message;
import ee.carlrobert.codegpt.settings.GeneralSettings;
import ee.carlrobert.codegpt.settings.service.ServiceType;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel;
/**
 * Tests for conversation lifecycle state: starting, saving, navigating between,
 * deleting and clearing conversations through {@code ConversationService}.
 */
public class ConversationsStateTest extends BasePlatformTestCase {

  public void testStartNewDefaultConversation() {
    GeneralSettings.getCurrentState().setSelectedService(ServiceType.OPENAI);
    OpenAISettings.getCurrentState().setModel(OpenAIChatCompletionModel.GPT_3_5.getCode());

    var startedConversation = ConversationService.getInstance().startConversation();

    // A freshly started conversation becomes the current one and inherits
    // the configured client code and model.
    assertThat(startedConversation).isEqualTo(ConversationsState.getCurrentConversation());
    assertThat(startedConversation)
        .extracting("clientCode", "model")
        .containsExactly("chat.completion", "gpt-3.5-turbo");
  }

  public void testSaveConversation() {
    var conversationService = ConversationService.getInstance();
    var conversation = conversationService.createConversation("chat.completion");
    conversationService.addConversation(conversation);
    var savedMessage = new Message("TEST_PROMPT");
    savedMessage.setResponse("TEST_RESPONSE");
    conversation.addMessage(savedMessage);

    conversationService.saveConversation(conversation);

    var currentConversation = ConversationsState.getCurrentConversation();
    assertThat(currentConversation).isNotNull();
    assertThat(currentConversation.getMessages())
        .flatExtracting("prompt", "response")
        .containsExactly("TEST_PROMPT", "TEST_RESPONSE");
  }

  public void testGetPreviousConversation() {
    var conversationService = ConversationService.getInstance();
    var initialConversation = conversationService.startConversation();
    conversationService.startConversation();

    var previousConversation = conversationService.getPreviousConversation();

    assertThat(previousConversation.isPresent()).isTrue();
    assertThat(previousConversation.get()).isEqualTo(initialConversation);
  }

  public void testGetNextConversation() {
    var conversationService = ConversationService.getInstance();
    var initialConversation = conversationService.startConversation();
    var followingConversation = conversationService.startConversation();
    // Jump back to the first conversation so a "next" one exists.
    ConversationsState.getInstance().setCurrentConversation(initialConversation);

    var nextConversation = conversationService.getNextConversation();

    assertThat(nextConversation.isPresent()).isTrue();
    assertThat(nextConversation.get()).isEqualTo(followingConversation);
  }

  public void testDeleteSelectedConversation() {
    var conversationService = ConversationService.getInstance();
    var remainingConversation = conversationService.startConversation();
    conversationService.startConversation();

    conversationService.deleteSelectedConversation();

    // Deleting the selected (latest) conversation promotes the remaining one.
    assertThat(ConversationsState.getCurrentConversation()).isEqualTo(remainingConversation);
    assertThat(conversationService.getSortedConversations().size()).isEqualTo(1);
    assertThat(conversationService.getSortedConversations())
        .extracting("id")
        .containsExactly(remainingConversation.getId());
  }

  public void testClearAllConversations() {
    var conversationService = ConversationService.getInstance();
    conversationService.startConversation();
    conversationService.startConversation();

    conversationService.clearAll();

    assertThat(ConversationsState.getCurrentConversation()).isNull();
    assertThat(conversationService.getSortedConversations().size()).isEqualTo(0);
  }
}

View file

@@ -1,81 +0,0 @@
package ee.carlrobert.codegpt.settings.state;
import static ee.carlrobert.codegpt.completions.HuggingFaceModel.CODE_LLAMA_7B_Q3;
import static org.assertj.core.api.Assertions.assertThat;
import com.intellij.testFramework.fixtures.BasePlatformTestCase;
import ee.carlrobert.codegpt.completions.HuggingFaceModel;
import ee.carlrobert.codegpt.conversations.Conversation;
import ee.carlrobert.codegpt.settings.GeneralSettings;
import ee.carlrobert.codegpt.settings.service.ServiceType;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
/**
 * Tests for {@code GeneralSettings.sync(Conversation)}: restoring a persisted
 * conversation must switch the selected service (derived from the client code)
 * and propagate the conversation's model into the service-specific settings.
 */
public class GeneralSettingsTest extends BasePlatformTestCase {

  public void testOpenAISettingsSync() {
    var openAiState = OpenAISettings.getCurrentState();
    openAiState.setModel("gpt-3.5-turbo");
    var conversation = new Conversation();
    conversation.setClientCode("chat.completion");
    conversation.setModel("gpt-4");
    var generalSettings = GeneralSettings.getInstance();

    generalSettings.sync(conversation);

    assertThat(generalSettings.getState().getSelectedService()).isEqualTo(ServiceType.OPENAI);
    // The conversation's model overwrites the previously configured one.
    assertThat(openAiState.getModel()).isEqualTo("gpt-4");
  }

  public void testAzureSettingsSync() {
    var generalSettings = GeneralSettings.getInstance();
    var conversation = new Conversation();
    conversation.setClientCode("azure.chat.completion");
    conversation.setModel("gpt-4");

    generalSettings.sync(conversation);

    assertThat(generalSettings.getState().getSelectedService()).isEqualTo(ServiceType.AZURE);
  }

  public void testYouSettingsSync() {
    var generalSettings = GeneralSettings.getInstance();
    var conversation = new Conversation();
    conversation.setClientCode("you.chat.completion");
    conversation.setModel("YouCode");

    generalSettings.sync(conversation);

    assertThat(generalSettings.getState().getSelectedService()).isEqualTo(ServiceType.YOU);
  }

  public void testLlamaSettingsModelPathSync() {
    var llamaState = LlamaSettings.getCurrentState();
    llamaState.setHuggingFaceModel(HuggingFaceModel.WIZARD_CODER_PYTHON_7B_Q3);
    var conversation = new Conversation();
    conversation.setClientCode("llama.chat.completion");
    // A model value that is not a HuggingFaceModel constant is treated as a
    // custom local model path.
    conversation.setModel("TEST_LLAMA_MODEL_PATH");
    var generalSettings = GeneralSettings.getInstance();

    generalSettings.sync(conversation);

    assertThat(generalSettings.getState().getSelectedService()).isEqualTo(ServiceType.LLAMA_CPP);
    assertThat(llamaState.getCustomLlamaModelPath()).isEqualTo("TEST_LLAMA_MODEL_PATH");
    assertThat(llamaState.isUseCustomModel()).isTrue();
  }

  public void testLlamaSettingsHuggingFaceModelSync() {
    var llamaState = LlamaSettings.getCurrentState();
    llamaState.setHuggingFaceModel(HuggingFaceModel.WIZARD_CODER_PYTHON_7B_Q3);
    var conversation = new Conversation();
    conversation.setClientCode("llama.chat.completion");
    // A model value matching a HuggingFaceModel constant selects that model.
    conversation.setModel("CODE_LLAMA_7B_Q3");
    var generalSettings = GeneralSettings.getInstance();

    generalSettings.sync(conversation);

    assertThat(generalSettings.getState().getSelectedService()).isEqualTo(ServiceType.LLAMA_CPP);
    assertThat(llamaState.getHuggingFaceModel()).isEqualTo(CODE_LLAMA_7B_Q3);
    assertThat(llamaState.isUseCustomModel()).isFalse();
  }
}

View file

@@ -1,415 +0,0 @@
package ee.carlrobert.codegpt.toolwindow.chat;
import static ee.carlrobert.codegpt.completions.CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT;
import static ee.carlrobert.codegpt.completions.CompletionRequestProvider.FIX_COMPILE_ERRORS_SYSTEM_PROMPT;
import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA;
import static ee.carlrobert.llm.client.util.JSONUtil.e;
import static ee.carlrobert.llm.client.util.JSONUtil.jsonArray;
import static ee.carlrobert.llm.client.util.JSONUtil.jsonMap;
import static ee.carlrobert.llm.client.util.JSONUtil.jsonMapResponse;
import static java.util.Objects.requireNonNull;
import static org.apache.http.HttpHeaders.AUTHORIZATION;
import static org.assertj.core.api.Assertions.assertThat;
import ee.carlrobert.codegpt.CodeGPTKeys;
import ee.carlrobert.codegpt.EncodingManager;
import ee.carlrobert.codegpt.ReferencedFile;
import ee.carlrobert.codegpt.completions.ConversationType;
import ee.carlrobert.codegpt.completions.HuggingFaceModel;
import ee.carlrobert.codegpt.conversations.ConversationService;
import ee.carlrobert.codegpt.conversations.message.Message;
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import testsupport.IntegrationTest;
public class ChatToolWindowTabPanelTest extends IntegrationTest {
// Sends one message through the chat tool window panel against the mocked
// OpenAI service and verifies the outgoing request payload, streamed response
// assembly, token-detail accounting and conversation persistence.
public void testSendingOpenAIMessage() {
  useOpenAIService();
  ConfigurationSettings.getCurrentState().setSystemPrompt(COMPLETION_SYSTEM_PROMPT);
  var message = new Message("Hello!");
  var conversation = ConversationService.getInstance().startConversation();
  var panel = new ChatToolWindowTabPanel(getProject(), conversation);
  expectOpenAI((StreamHttpExchange) request -> {
    assertThat(request.getUri().getPath()).isEqualTo("/v1/chat/completions");
    assertThat(request.getMethod()).isEqualTo("POST");
    assertThat(request.getHeaders().get(AUTHORIZATION).get(0)).isEqualTo("Bearer TEST_API_KEY");
    assertThat(request.getBody())
        .extracting(
            "model",
            "messages")
        .containsExactly(
            "gpt-4",
            List.of(
                Map.of("role", "system", "content", COMPLETION_SYSTEM_PROMPT),
                Map.of("role", "user", "content", "Hello!")));
    // SSE chunks: role delta first, then content deltas concatenating to "Hello!".
    return List.of(
        jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
        jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
        jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
        jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))));
  });
  panel.sendMessage(message);
  // Block until the streamed response has been applied to the conversation.
  waitExpecting(() -> {
    var messages = conversation.getMessages();
    return !messages.isEmpty() && "Hello!".equals(messages.get(0).getResponse());
  });
  var encodingManager = EncodingManager.getInstance();
  assertThat(panel.getTokenDetails()).extracting(
          "systemPromptTokens",
          "conversationTokens",
          "userPromptTokens",
          "highlightedTokens")
      .containsExactly(
          encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT),
          encodingManager.countTokens(message.getPrompt()),
          0,
          0);
  assertThat(panel.getConversation())
      .isNotNull()
      .extracting("id", "model", "clientCode", "discardTokenLimit")
      .containsExactly(
          conversation.getId(),
          conversation.getModel(),
          conversation.getClientCode(),
          false);
  var messages = panel.getConversation().getMessages();
  assertThat(messages.size()).isOne();
  assertThat(messages.get(0))
      .extracting("id", "prompt", "response")
      .containsExactly(message.getId(), message.getPrompt(), message.getResponse());
}
// Same flow as testSendingOpenAIMessage, but with referenced files attached:
// the selected files' paths and contents must be inlined into the user prompt
// and the message must persist its referenced file paths.
public void testSendingOpenAIMessageWithReferencedContext() {
  getProject().putUserData(CodeGPTKeys.SELECTED_FILES, List.of(
      new ReferencedFile("TEST_FILE_NAME_1", "TEST_FILE_PATH_1", "TEST_FILE_CONTENT_1"),
      new ReferencedFile("TEST_FILE_NAME_2", "TEST_FILE_PATH_2", "TEST_FILE_CONTENT_2"),
      new ReferencedFile("TEST_FILE_NAME_3", "TEST_FILE_PATH_3", "TEST_FILE_CONTENT_3")));
  useOpenAIService();
  ConfigurationSettings.getCurrentState().setSystemPrompt(COMPLETION_SYSTEM_PROMPT);
  var message = new Message("TEST_MESSAGE");
  message.setUserMessage("TEST_MESSAGE");
  message.setReferencedFilePaths(
      List.of("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3"));
  var conversation = ConversationService.getInstance().startConversation();
  var panel = new ChatToolWindowTabPanel(getProject(), conversation);
  expectOpenAI((StreamHttpExchange) request -> {
    assertThat(request.getUri().getPath()).isEqualTo("/v1/chat/completions");
    assertThat(request.getMethod()).isEqualTo("POST");
    assertThat(request.getHeaders().get(AUTHORIZATION).get(0)).isEqualTo("Bearer TEST_API_KEY");
    // The user content is the context-expansion template, not the raw prompt.
    assertThat(request.getBody())
        .extracting(
            "model",
            "messages")
        .containsExactly(
            "gpt-4",
            List.of(
                Map.of("role", "system", "content", COMPLETION_SYSTEM_PROMPT),
                Map.of("role", "user", "content",
                    """
                        Use the following context to answer question at the end:
                        File Path: TEST_FILE_PATH_1
                        File Content:
                        ```TEST_FILE_NAME_1
                        TEST_FILE_CONTENT_1
                        ```
                        File Path: TEST_FILE_PATH_2
                        File Content:
                        ```TEST_FILE_NAME_2
                        TEST_FILE_CONTENT_2
                        ```
                        File Path: TEST_FILE_PATH_3
                        File Content:
                        ```TEST_FILE_NAME_3
                        TEST_FILE_CONTENT_3
                        ```
                        Question: TEST_MESSAGE""")));
    return List.of(
        jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
        jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
        jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
        jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))));
  });
  panel.sendMessage(message);
  // Block until the streamed response has been applied to the conversation.
  waitExpecting(() -> {
    var messages = conversation.getMessages();
    return !messages.isEmpty() && "Hello!".equals(messages.get(0).getResponse());
  });
  var encodingManager = EncodingManager.getInstance();
  assertThat(panel.getTokenDetails()).extracting(
          "systemPromptTokens",
          "conversationTokens",
          "userPromptTokens",
          "highlightedTokens")
      .containsExactly(
          encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT),
          encodingManager.countTokens(message.getPrompt()),
          0,
          0);
  assertThat(panel.getConversation())
      .isNotNull()
      .extracting("id", "model", "clientCode", "discardTokenLimit")
      .containsExactly(
          conversation.getId(),
          conversation.getModel(),
          conversation.getClientCode(),
          false);
  var messages = panel.getConversation().getMessages();
  assertThat(messages.size()).isOne();
  assertThat(messages.get(0))
      .extracting("id", "prompt", "response", "referencedFilePaths")
      .containsExactly(
          message.getId(),
          message.getPrompt(),
          message.getResponse(),
          List.of("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3"));
}
// Sends a message with an attached image to the vision-capable model and
// verifies the image is transmitted as a base64 data URL alongside the text.
public void testSendingOpenAIMessageWithImage() {
  var testImagePath = requireNonNull(getClass().getResource("/images/test-image.png")).getPath();
  getProject().putUserData(CodeGPTKeys.IMAGE_ATTACHMENT_FILE_PATH, testImagePath);
  useOpenAIService("gpt-4-vision-preview");
  ConfigurationSettings.getCurrentState().setSystemPrompt(COMPLETION_SYSTEM_PROMPT);
  var message = new Message("TEST_MESSAGE");
  var conversation = ConversationService.getInstance().startConversation();
  var panel = new ChatToolWindowTabPanel(getProject(), conversation);
  expectOpenAI((StreamHttpExchange) request -> {
    assertThat(request.getUri().getPath()).isEqualTo("/v1/chat/completions");
    assertThat(request.getMethod()).isEqualTo("POST");
    assertThat(request.getHeaders().get(AUTHORIZATION).get(0)).isEqualTo("Bearer TEST_API_KEY");
    try {
      // The attachment must be sent inline as a base64-encoded data URL.
      var testImageUrl = "data:image/png;base64,"
          + Base64.getEncoder().encodeToString(Files.readAllBytes(Path.of(testImagePath)));
      assertThat(request.getBody())
          .extracting("model", "messages")
          .containsExactly(
              "gpt-4-vision-preview",
              List.of(
                  Map.of("role", "system", "content", COMPLETION_SYSTEM_PROMPT),
                  Map.of("role", "user", "content", List.of(
                      Map.of(
                          "type", "image_url",
                          "image_url", Map.of("url", testImageUrl)),
                      Map.of("type", "text", "text", "TEST_MESSAGE")
                  ))));
    } catch (IOException e) {
      // Reading the fixture image should never fail; surface it as a test error.
      throw new RuntimeException(e);
    }
    return List.of(
        jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
        jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
        jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
        jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))));
  });
  panel.sendMessage(message);
  // Block until the streamed response has been applied to the conversation.
  waitExpecting(() -> {
    var messages = conversation.getMessages();
    return !messages.isEmpty() && "Hello!".equals(messages.get(0).getResponse());
  });
  var encodingManager = EncodingManager.getInstance();
  assertThat(panel.getTokenDetails()).extracting(
          "systemPromptTokens",
          "conversationTokens",
          "userPromptTokens",
          "highlightedTokens")
      .containsExactly(
          encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT),
          encodingManager.countTokens(message.getPrompt()),
          0,
          0);
  assertThat(panel.getConversation())
      .isNotNull()
      .extracting("id", "model", "clientCode", "discardTokenLimit")
      .containsExactly(
          conversation.getId(),
          conversation.getModel(),
          conversation.getClientCode(),
          false);
  var messages = panel.getConversation().getMessages();
  assertThat(messages.size()).isOne();
  assertThat(messages.get(0))
      .extracting("id", "prompt", "response", "imageFilePath")
      .containsExactly(
          message.getId(),
          message.getPrompt(),
          message.getResponse(),
          message.getImageFilePath());
}
/**
 * Verifies the "fix compile errors" conversation type: the referenced files are
 * inlined into the user prompt, FIX_COMPILE_ERRORS_SYSTEM_PROMPT is used as the
 * system message, and the streamed response plus the referenced file paths are
 * persisted on the conversation.
 */
public void testFixCompileErrorsWithOpenAIService() {
  // Referenced files are provided to the panel through project user data.
  getProject().putUserData(CodeGPTKeys.SELECTED_FILES, List.of(
      new ReferencedFile("TEST_FILE_NAME_1", "TEST_FILE_PATH_1", "TEST_FILE_CONTENT_1"),
      new ReferencedFile("TEST_FILE_NAME_2", "TEST_FILE_PATH_2", "TEST_FILE_CONTENT_2"),
      new ReferencedFile("TEST_FILE_NAME_3", "TEST_FILE_PATH_3", "TEST_FILE_CONTENT_3")));
  useOpenAIService();
  ConfigurationSettings.getCurrentState().setSystemPrompt(COMPLETION_SYSTEM_PROMPT);
  var message = new Message("TEST_MESSAGE");
  message.setUserMessage("TEST_MESSAGE");
  message.setReferencedFilePaths(
      List.of("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3"));
  var conversation = ConversationService.getInstance().startConversation();
  var panel = new ChatToolWindowTabPanel(getProject(), conversation);
  // Stub the OpenAI endpoint: the user message must contain the inlined file
  // context followed by the actual question.
  expectOpenAI((StreamHttpExchange) request -> {
    assertThat(request.getUri().getPath()).isEqualTo("/v1/chat/completions");
    assertThat(request.getMethod()).isEqualTo("POST");
    assertThat(request.getHeaders().get(AUTHORIZATION).get(0)).isEqualTo("Bearer TEST_API_KEY");
    assertThat(request.getBody())
        .extracting(
            "model",
            "messages")
        .containsExactly(
            "gpt-4",
            List.of(
                Map.of("role", "system", "content", FIX_COMPILE_ERRORS_SYSTEM_PROMPT),
                Map.of("role", "user", "content",
                    """
                        Use the following context to answer question at the end:
                        File Path: TEST_FILE_PATH_1
                        File Content:
                        ```TEST_FILE_NAME_1
                        TEST_FILE_CONTENT_1
                        ```
                        File Path: TEST_FILE_PATH_2
                        File Content:
                        ```TEST_FILE_NAME_2
                        TEST_FILE_CONTENT_2
                        ```
                        File Path: TEST_FILE_PATH_3
                        File Content:
                        ```TEST_FILE_NAME_3
                        TEST_FILE_CONTENT_3
                        ```
                        Question: TEST_MESSAGE""")));
    return List.of(
        jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
        jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
        jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
        jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))));
  });
  panel.sendMessage(message, ConversationType.FIX_COMPILE_ERRORS);
  // Wait until the streamed deltas have been combined into the final response.
  waitExpecting(() -> {
    var messages = conversation.getMessages();
    return !messages.isEmpty() && "Hello!".equals(messages.get(0).getResponse());
  });
  var encodingManager = EncodingManager.getInstance();
  // NOTE(review): the request was sent with FIX_COMPILE_ERRORS_SYSTEM_PROMPT,
  // but the token details are asserted against COMPLETION_SYSTEM_PROMPT —
  // confirm this asymmetry is intended.
  assertThat(panel.getTokenDetails()).extracting(
      "systemPromptTokens",
      "conversationTokens",
      "userPromptTokens",
      "highlightedTokens")
      .containsExactly(
          encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT),
          encodingManager.countTokens(message.getPrompt()),
          0,
          0);
  assertThat(panel.getConversation())
      .isNotNull()
      .extracting("id", "model", "clientCode", "discardTokenLimit")
      .containsExactly(
          conversation.getId(),
          conversation.getModel(),
          conversation.getClientCode(),
          false);
  var messages = panel.getConversation().getMessages();
  assertThat(messages.size()).isOne();
  // The referenced file paths must survive the round trip.
  assertThat(messages.get(0))
      .extracting("id", "prompt", "response", "referencedFilePaths")
      .containsExactly(
          message.getId(),
          message.getPrompt(),
          message.getResponse(),
          List.of("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3"));
}
/**
 * Verifies that a chat message sent through the llama.cpp service carries the
 * configured sampling parameters and the prompt rendered by the LLAMA template,
 * and that the streamed response is persisted on the conversation.
 */
public void testSendingLlamaMessage() {
  useLlamaService();
  var configurationState = ConfigurationSettings.getCurrentState();
  configurationState.setSystemPrompt(COMPLETION_SYSTEM_PROMPT);
  configurationState.setMaxTokens(1000);
  configurationState.setTemperature(0.1);
  // Sampling parameters below must appear verbatim in the /completion request body.
  var llamaSettings = LlamaSettings.getCurrentState();
  llamaSettings.setUseCustomModel(false);
  llamaSettings.setHuggingFaceModel(HuggingFaceModel.CODE_LLAMA_7B_Q4);
  llamaSettings.setTopK(30);
  llamaSettings.setTopP(0.8);
  llamaSettings.setMinP(0.03);
  llamaSettings.setRepeatPenalty(1.3);
  var message = new Message("TEST_PROMPT");
  var conversation = ConversationService.getInstance().startConversation();
  var panel = new ChatToolWindowTabPanel(getProject(), conversation);
  expectLlama((StreamHttpExchange) request -> {
    assertThat(request.getUri().getPath()).isEqualTo("/completion");
    assertThat(request.getBody())
        .extracting(
            "prompt",
            "n_predict",
            "stream",
            "temperature",
            "top_k",
            "top_p",
            "min_p",
            "repeat_penalty")
        .containsExactly(
            LLAMA.buildPrompt(
                COMPLETION_SYSTEM_PROMPT,
                "TEST_PROMPT",
                conversation.getMessages()),
            configurationState.getMaxTokens(),
            true,
            configurationState.getTemperature(),
            llamaSettings.getTopK(),
            llamaSettings.getTopP(),
            llamaSettings.getMinP(),
            llamaSettings.getRepeatPenalty());
    // Stream "Hello!" in chunks; the final event signals the stop condition.
    return List.of(
        jsonMapResponse("content", "Hel"),
        jsonMapResponse("content", "lo!"),
        jsonMapResponse(
            e("content", ""),
            e("stop", true)));
  });
  panel.sendMessage(message, ConversationType.DEFAULT);
  // Wait until the streamed chunks have been combined into the final response.
  waitExpecting(() -> {
    var messages = conversation.getMessages();
    return !messages.isEmpty() && "Hello!".equals(messages.get(0).getResponse());
  });
  assertThat(panel.getConversation())
      .isNotNull()
      .extracting("id", "model", "clientCode", "discardTokenLimit")
      .containsExactly(
          conversation.getId(),
          conversation.getModel(),
          conversation.getClientCode(),
          false);
  var messages = panel.getConversation().getMessages();
  assertThat(messages.size()).isOne();
  assertThat(messages.get(0))
      .extracting("id", "prompt", "response")
      .containsExactly(message.getId(), message.getPrompt(), message.getResponse());
}
}

View file

@ -1,50 +0,0 @@
package ee.carlrobert.codegpt.toolwindow.chat;
import static org.assertj.core.api.Assertions.assertThat;
import com.intellij.openapi.util.Disposer;
import com.intellij.testFramework.fixtures.BasePlatformTestCase;
import ee.carlrobert.codegpt.conversations.ConversationService;
import ee.carlrobert.codegpt.conversations.message.Message;
/**
 * Behavioral tests for {@code ChatToolWindowTabbedPane}: clearing, adding and
 * resetting chat tabs.
 */
public class ChatToolWindowTabbedPaneTest extends BasePlatformTestCase {

  public void testClearAllTabs() {
    var pane = new ChatToolWindowTabbedPane(Disposer.newDisposable());
    pane.addNewTab(newTabPanel());

    pane.clearAll();

    // No tabs may remain after clearing.
    assertThat(pane.getActiveTabMapping()).isEmpty();
  }

  public void testAddingNewTabs() {
    var pane = new ChatToolWindowTabbedPane(Disposer.newDisposable());
    pane.addNewTab(newTabPanel());
    pane.addNewTab(newTabPanel());
    pane.addNewTab(newTabPanel());

    // Tabs are titled sequentially in insertion order.
    assertThat(pane.getActiveTabMapping().keySet())
        .containsExactly("Chat 1", "Chat 2", "Chat 3");
  }

  public void testResetCurrentlyActiveTabPanel() {
    var pane = new ChatToolWindowTabbedPane(Disposer.newDisposable());
    var conversation = ConversationService.getInstance().startConversation();
    conversation.addMessage(new Message("TEST_PROMPT", "TEST_RESPONSE"));
    pane.addNewTab(new ChatToolWindowTabPanel(getProject(), conversation));

    pane.resetCurrentlyActiveTabPanel(getProject());

    // Resetting replaces the active tab's conversation with an empty one.
    assertThat(pane.getActiveTabMapping().get("Chat 1").getConversation().getMessages())
        .isEmpty();
  }

  /** Creates a tab panel backed by a brand-new conversation. */
  private ChatToolWindowTabPanel newTabPanel() {
    return new ChatToolWindowTabPanel(
        getProject(),
        ConversationService.getInstance().startConversation());
  }
}

View file

@ -1,34 +0,0 @@
package testsupport;
import com.intellij.openapi.util.Key;
import com.intellij.testFramework.fixtures.BasePlatformTestCase;
import ee.carlrobert.codegpt.CodeGPTKeys;
import ee.carlrobert.llm.client.mixin.ExternalServiceTestMixin;
import java.util.Collections;
import testsupport.mixin.ShortcutsTestMixin;
/**
 * Base class for integration tests that talk to the stubbed external LLM
 * services. Boots the shared service stubs once and resets both the stubs and
 * the project-level user-data keys after every test.
 */
public class IntegrationTest extends BasePlatformTestCase implements
    ExternalServiceTestMixin,
    ShortcutsTestMixin {

  static {
    // Initialize the shared external-service stubs once for the whole run.
    ExternalServiceTestMixin.init();
  }

  @Override
  protected void tearDown() throws Exception {
    ExternalServiceTestMixin.clearAll();
    resetUserData();
    super.tearDown();
  }

  /** Restores the project user-data keys that individual tests may have mutated. */
  private void resetUserData() {
    var project = getProject();
    project.putUserData(CodeGPTKeys.SELECTED_FILES, Collections.emptyList());
    project.putUserData(CodeGPTKeys.PREVIOUS_INLAY_TEXT, "");
    project.putUserData(CodeGPTKeys.IMAGE_ATTACHMENT_FILE_PATH, "");
  }
}

View file

@ -1,51 +0,0 @@
package testsupport.mixin;
import static ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.AZURE_OPENAI_API_KEY;
import static ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.OPENAI_API_KEY;
import com.intellij.testFramework.PlatformTestUtil;
import ee.carlrobert.codegpt.credentials.CredentialsStore;
import ee.carlrobert.codegpt.settings.GeneralSettings;
import ee.carlrobert.codegpt.settings.service.ServiceType;
import ee.carlrobert.codegpt.settings.service.azure.AzureSettings;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
import java.util.function.BooleanSupplier;
/**
 * Mixin providing shortcuts for pointing the plugin at a specific (stubbed)
 * LLM service inside tests, plus a helper for awaiting asynchronous results.
 */
public interface ShortcutsTestMixin {
  /** Selects the OpenAI service with the default "gpt-4" model. */
  default void useOpenAIService() {
    useOpenAIService("gpt-4");
  }
  /** Selects the OpenAI service with the given model and installs a test API key. */
  default void useOpenAIService(String model) {
    GeneralSettings.getCurrentState().setSelectedService(ServiceType.OPENAI);
    CredentialsStore.INSTANCE.setCredential(OPENAI_API_KEY, "TEST_API_KEY");
    OpenAISettings.getCurrentState().setModel(model);
  }
  /** Selects the Azure OpenAI service with test credentials and deployment data. */
  default void useAzureService() {
    GeneralSettings.getCurrentState().setSelectedService(ServiceType.AZURE);
    CredentialsStore.INSTANCE.setCredential(AZURE_OPENAI_API_KEY, "TEST_API_KEY");
    var azureSettings = AzureSettings.getCurrentState();
    azureSettings.setResourceName("TEST_RESOURCE_NAME");
    azureSettings.setApiVersion("TEST_API_VERSION");
    azureSettings.setDeploymentId("TEST_DEPLOYMENT_ID");
  }
  /** Selects the You.com service (no credentials required). */
  default void useYouService() {
    GeneralSettings.getCurrentState().setSelectedService(ServiceType.YOU);
  }
  /**
   * Selects the llama.cpp service.
   * NOTE(review): the null server port presumably routes the client to the
   * local stub server — confirm against the llama client configuration.
   */
  default void useLlamaService() {
    GeneralSettings.getCurrentState().setSelectedService(ServiceType.LLAMA_CPP);
    LlamaSettings.getCurrentState().setServerPort(null);
  }
  /** Dispatches UI events until the condition holds or the 5-second timeout expires. */
  default void waitExpecting(BooleanSupplier condition) {
    PlatformTestUtil.waitWithEventsDispatching(
        "Waiting for message response timed out or did not meet expected conditions",
        condition,
        5);
  }
}

View file

@ -0,0 +1,50 @@
package ee.carlrobert.codegpt.codecompletions
import com.intellij.openapi.editor.VisualPosition
import ee.carlrobert.codegpt.CodeGPTKeys.PREVIOUS_INLAY_TEXT
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings
import ee.carlrobert.codegpt.util.file.FileUtil.getResourceContent
import ee.carlrobert.llm.client.http.RequestEntity
import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange
import ee.carlrobert.llm.client.util.JSONUtil.e
import ee.carlrobert.llm.client.util.JSONUtil.jsonMapResponse
import org.assertj.core.api.Assertions.assertThat
import testsupport.IntegrationTest
/**
 * Verifies that typing in the editor triggers a llama.cpp infill request built
 * from the text windows around the caret and that the returned completion ends
 * up in the previous-inlay-text key.
 */
class CodeCompletionServiceTest : IntegrationTest() {
    // Caret position inside the configured completion test file.
    private val cursorPosition = VisualPosition(3, 0)
    fun testFetchCodeCompletionLlama() {
        useLlamaService()
        LlamaSettings.getCurrentState().isCodeCompletionsEnabled = true
        myFixture.configureByText(
            "CompletionTest.java",
            getResourceContent("/codecompletions/code-completion-file.txt")
        )
        myFixture.editor.caretModel.moveToVisualPosition(cursorPosition)
        val expectedCompletion = "TEST_OUTPUT"
        // Text before the caret, including the 'c' typed below.
        val prefix = """
            ${"z".repeat(245)}
            [INPUT]
            c
            """.trimIndent() // 128 tokens
        // Text after the caret.
        // NOTE(review): the Java original built this as "\n[\\INPUT]\n" + "z".repeat(247),
        // i.e. with a leading newline; trimIndent() here produces no leading newline —
        // confirm the converted value still matches what the service sends.
        val suffix = """
            [\INPUT]
            ${"z".repeat(247)}
            """.trimIndent() // 128 tokens
        // Stub the llama.cpp /completion endpoint and assert the infill prompt.
        expectLlama(StreamHttpExchange { request: RequestEntity ->
            assertThat(request.uri.path).isEqualTo("/completion")
            assertThat(request.method).isEqualTo("POST")
            assertThat(request.body)
                .extracting("prompt")
                .isEqualTo(InfillPromptTemplate.LLAMA.buildPrompt(prefix, suffix))
            listOf(jsonMapResponse(e("content", expectedCompletion), e("stop", true)))
        })
        // Typing a character triggers the code-completion request.
        myFixture.type('c')
        // The streamed completion should be recorded as the current inlay text.
        waitExpecting { "TEST_OUTPUT" == PREVIOUS_INLAY_TEXT[myFixture.editor] }
    }
}

View file

@ -0,0 +1,152 @@
package ee.carlrobert.codegpt.completions
import ee.carlrobert.codegpt.conversations.ConversationService
import ee.carlrobert.codegpt.conversations.message.Message
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey
import ee.carlrobert.codegpt.credentials.CredentialsStore.setCredential
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.groups.Tuple
import testsupport.IntegrationTest
/**
 * Tests for [CompletionRequestProvider.buildOpenAIChatCompletionRequest]:
 * system-prompt selection, retry handling and token-limit truncation.
 */
class CompletionRequestProviderTest : IntegrationTest() {

    fun testChatCompletionRequestWithSystemPromptOverride() {
        setCredential(CredentialKey.OPENAI_API_KEY, "TEST_API_KEY")
        // A user-configured system prompt must replace the built-in default.
        ConfigurationSettings.getCurrentState().systemPrompt = "TEST_SYSTEM_PROMPT"
        val conversation = ConversationService.getInstance().startConversation()
        val firstMessage = createDummyMessage(500)
        val secondMessage = createDummyMessage(250)
        conversation.addMessage(firstMessage)
        conversation.addMessage(secondMessage)
        val request = CompletionRequestProvider(conversation)
            .buildOpenAIChatCompletionRequest(
                OpenAIChatCompletionModel.GPT_3_5.code,
                CallParameters(
                    conversation,
                    ConversationType.DEFAULT,
                    Message("TEST_CHAT_COMPLETION_PROMPT"),
                    false))
        // The history fits the token budget, so every exchange is included.
        assertThat(request.messages)
            .extracting("role", "content")
            .containsExactly(
                Tuple.tuple("system", "TEST_SYSTEM_PROMPT"),
                Tuple.tuple("user", "TEST_PROMPT"),
                Tuple.tuple("assistant", firstMessage.response),
                Tuple.tuple("user", "TEST_PROMPT"),
                Tuple.tuple("assistant", secondMessage.response),
                Tuple.tuple("user", "TEST_CHAT_COMPLETION_PROMPT"))
    }

    fun testChatCompletionRequestWithoutSystemPromptOverride() {
        // Without an override the provider uses COMPLETION_SYSTEM_PROMPT.
        val conversation = ConversationService.getInstance().startConversation()
        val firstMessage = createDummyMessage(500)
        val secondMessage = createDummyMessage(250)
        conversation.addMessage(firstMessage)
        conversation.addMessage(secondMessage)
        val request = CompletionRequestProvider(conversation)
            .buildOpenAIChatCompletionRequest(
                OpenAIChatCompletionModel.GPT_3_5.code,
                CallParameters(
                    conversation,
                    ConversationType.DEFAULT,
                    Message("TEST_CHAT_COMPLETION_PROMPT"),
                    false))
        assertThat(request.messages)
            .extracting("role", "content")
            .containsExactly(
                Tuple.tuple("system", CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT),
                Tuple.tuple("user", "TEST_PROMPT"),
                Tuple.tuple("assistant", firstMessage.response),
                Tuple.tuple("user", "TEST_PROMPT"),
                Tuple.tuple("assistant", secondMessage.response),
                Tuple.tuple("user", "TEST_CHAT_COMPLETION_PROMPT"))
    }

    fun testChatCompletionRequestRetry() {
        ConfigurationSettings.getCurrentState().systemPrompt = CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT
        val conversation = ConversationService.getInstance().startConversation()
        val firstMessage = createDummyMessage("FIRST_TEST_PROMPT", 500)
        val secondMessage = createDummyMessage("SECOND_TEST_PROMPT", 250)
        conversation.addMessage(firstMessage)
        conversation.addMessage(secondMessage)
        // Retrying the second message resends its prompt without the previous response.
        val request = CompletionRequestProvider(conversation)
            .buildOpenAIChatCompletionRequest(
                OpenAIChatCompletionModel.GPT_3_5.code,
                CallParameters(
                    conversation,
                    ConversationType.DEFAULT,
                    secondMessage,
                    true))
        assertThat(request.messages)
            .extracting("role", "content")
            .containsExactly(
                Tuple.tuple("system", CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT),
                Tuple.tuple("user", "FIRST_TEST_PROMPT"),
                Tuple.tuple("assistant", firstMessage.response),
                Tuple.tuple("user", "SECOND_TEST_PROMPT"))
    }

    fun testReducedChatCompletionRequest() {
        val conversation = ConversationService.getInstance().startConversation()
        conversation.addMessage(createDummyMessage(50))
        conversation.addMessage(createDummyMessage(100))
        conversation.addMessage(createDummyMessage(150))
        conversation.addMessage(createDummyMessage(1000))
        val remainingMessage = createDummyMessage(2000)
        conversation.addMessage(remainingMessage)
        conversation.discardTokenLimits()
        val request = CompletionRequestProvider(conversation)
            .buildOpenAIChatCompletionRequest(
                OpenAIChatCompletionModel.GPT_3_5.code,
                CallParameters(
                    conversation,
                    ConversationType.DEFAULT,
                    Message("TEST_CHAT_COMPLETION_PROMPT"),
                    false))
        // Older exchanges are discarded; only the most recent one still fits the budget.
        assertThat(request.messages)
            .extracting("role", "content")
            .containsExactly(
                Tuple.tuple("system", CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT),
                Tuple.tuple("user", "TEST_PROMPT"),
                Tuple.tuple("assistant", remainingMessage.response),
                Tuple.tuple("user", "TEST_CHAT_COMPLETION_PROMPT"))
    }

    fun testTotalUsageExceededException() {
        val conversation = ConversationService.getInstance().startConversation()
        conversation.addMessage(createDummyMessage(1500))
        conversation.addMessage(createDummyMessage(1500))
        conversation.addMessage(createDummyMessage(1500))
        // Even after truncation the request cannot fit, so the provider throws.
        assertThrows(TotalUsageExceededException::class.java) {
            CompletionRequestProvider(conversation)
                .buildOpenAIChatCompletionRequest(
                    OpenAIChatCompletionModel.GPT_3_5.code,
                    CallParameters(
                        conversation,
                        ConversationType.DEFAULT,
                        createDummyMessage(100),
                        false)) }
    }

    // Creates a message with the default "TEST_PROMPT" prompt sized to tokenSize.
    private fun createDummyMessage(tokenSize: Int): Message {
        return createDummyMessage("TEST_PROMPT", tokenSize)
    }

    // Creates a message whose prompt + response together approximate tokenSize tokens.
    private fun createDummyMessage(prompt: String, tokenSize: Int): Message {
        val message = Message(prompt)
        // 'zz' = 1 token; the prompt is 6 tokens; per-message overhead is 7 tokens (GPT-3).
        message.response = "zz".repeat((tokenSize) - 6 - 7)
        return message
    }
}

View file

@ -0,0 +1,179 @@
package ee.carlrobert.codegpt.completions
import ee.carlrobert.codegpt.CodeGPTPlugin
import ee.carlrobert.codegpt.completions.CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT
import ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA
import ee.carlrobert.codegpt.conversations.ConversationService
import ee.carlrobert.codegpt.conversations.message.Message
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
import ee.carlrobert.llm.client.http.RequestEntity
import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange
import ee.carlrobert.llm.client.util.JSONUtil.e
import ee.carlrobert.llm.client.util.JSONUtil.jsonArray
import ee.carlrobert.llm.client.util.JSONUtil.jsonMap
import ee.carlrobert.llm.client.util.JSONUtil.jsonMapResponse
import org.apache.http.HttpHeaders
import org.assertj.core.api.Assertions.assertThat
import testsupport.IntegrationTest
/**
 * Integration tests for [CompletionRequestHandler]: verifies the exact request
 * each provider (OpenAI, Azure, You.com, llama.cpp) receives and that the
 * streamed chunks are assembled into the message response.
 *
 * Fixes over the previous revision: the duplicated path assertion in
 * [testYouChatCompletionCall] was removed, and the llama response list no
 * longer uses an unnecessarily nullable element type.
 */
class DefaultCompletionRequestHandlerTest : IntegrationTest() {

    fun testOpenAIChatCompletionCall() {
        useOpenAIService()
        val message = Message("TEST_PROMPT")
        val conversation = ConversationService.getInstance().startConversation()
        val requestHandler = CompletionRequestHandler(getRequestEventListener(message))
        // Stub the OpenAI endpoint: assert the outgoing request, then stream
        // "Hello!" back in three content deltas.
        expectOpenAI(StreamHttpExchange { request: RequestEntity ->
            assertThat(request.uri.path).isEqualTo("/v1/chat/completions")
            assertThat(request.method).isEqualTo("POST")
            assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY")
            assertThat(request.body)
                .extracting(
                    "model",
                    "messages")
                .containsExactly(
                    "gpt-4",
                    listOf(
                        mapOf("role" to "system", "content" to COMPLETION_SYSTEM_PROMPT),
                        mapOf("role" to "user", "content" to "TEST_PROMPT")))
            listOf(
                jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
                jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
                jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
                jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))))
        })
        requestHandler.call(CallParameters(conversation, ConversationType.DEFAULT, message, false))
        waitExpecting { "Hello!" == message.response }
    }

    fun testAzureChatCompletionCall() {
        useAzureService()
        val conversationService = ConversationService.getInstance()
        val prevMessage = Message("TEST_PREV_PROMPT")
        prevMessage.response = "TEST_PREV_RESPONSE"
        val conversation = conversationService.startConversation()
        conversation.addMessage(prevMessage)
        conversationService.saveConversation(conversation)
        // The Azure request must target the configured deployment/api-version and
        // include the prior exchange before the new prompt.
        expectAzure(StreamHttpExchange { request: RequestEntity ->
            assertThat(request.uri.path).isEqualTo(
                "/openai/deployments/TEST_DEPLOYMENT_ID/chat/completions")
            assertThat(request.uri.query).isEqualTo("api-version=TEST_API_VERSION")
            assertThat(request.headers["Api-key"]!![0]).isEqualTo("TEST_API_KEY")
            assertThat(request.headers["X-llm-application-tag"]!![0]).isEqualTo("codegpt")
            assertThat(request.body)
                .extracting("messages")
                .isEqualTo(
                    listOf(
                        mapOf("role" to "system", "content" to COMPLETION_SYSTEM_PROMPT),
                        mapOf("role" to "user", "content" to "TEST_PREV_PROMPT"),
                        mapOf("role" to "assistant", "content" to "TEST_PREV_RESPONSE"),
                        mapOf("role" to "user", "content" to "TEST_PROMPT")))
            listOf(
                jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
                jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
                jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
                jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))))
        })
        val message = Message("TEST_PROMPT")
        val requestHandler = CompletionRequestHandler(getRequestEventListener(message))
        requestHandler.call(CallParameters(conversation, ConversationType.DEFAULT, message, false))
        waitExpecting { "Hello!" == message.response }
    }

    fun testYouChatCompletionCall() {
        useYouService()
        val message = Message("TEST_PROMPT")
        val conversation = ConversationService.getInstance().startConversation()
        conversation.addMessage(Message("Ping", "Pong"))
        val requestHandler = CompletionRequestHandler(getRequestEventListener(message))
        expectYou(StreamHttpExchange { request: RequestEntity ->
            // FIX: the path was previously asserted twice; the duplicate was removed.
            assertThat(request.uri.path).isEqualTo("/api/streamingSearch")
            assertThat(request.method).isEqualTo("GET")
            assertThat(request.uri.query).isEqualTo(
                "q=TEST_PROMPT&"
                        + "page=1&"
                        + "cfr=CodeGPT&"
                        + "count=10&"
                        + "safeSearch=WebPages,Translations,TimeZone,Computation,RelatedSearches&"
                        + "domain=youchat&"
                        + "selectedChatMode=default&"
                        + "chat=[{\"question\":\"Ping\",\"answer\":\"Pong\"}]&"
                        + "utm_source=ide&"
                        + "utm_medium=jetbrains&"
                        + "utm_campaign=" + CodeGPTPlugin.getVersion() + "&"
                        + "utm_content=CodeGPT")
            assertThat(request.headers)
                .flatExtracting("Accept", "Connection", "User-agent", "Cookie")
                .containsExactly(
                    "text/event-stream",
                    "Keep-Alive",
                    "youide CodeGPT",
                    "safesearch_guest=Moderate; "
                            + "youpro_subscription=true; "
                            + "you_subscription=free; "
                            + "stytch_session=; "
                            + "ydc_stytch_session=; "
                            + "stytch_session_jwt=; "
                            + "ydc_stytch_session_jwt=; "
                            + "eg4=false; "
                            + "__cf_bm=aN2b3pQMH8XADeMB7bg9s1bJ_bfXBcCHophfOGRg6g0-1693601599-0-"
                            + "AWIt5Mr4Y3xQI4mIJ1lSf4+vijWKDobrty8OopDeBxY+NABe0MRFidF3dCUoWjRt8"
                            + "SVMvBZPI3zkOgcRs7Mz3yazd7f7c58HwW5Xg9jdBjNg;")
            listOf(
                jsonMapResponse("youChatToken", "Hel"),
                jsonMapResponse("youChatToken", "lo"),
                jsonMapResponse("youChatToken", "!"))
        })
        requestHandler.call(CallParameters(conversation, ConversationType.DEFAULT, message, false))
        waitExpecting { "Hello!" == message.response }
    }

    fun testLlamaChatCompletionCall() {
        useLlamaService()
        ConfigurationSettings.getCurrentState().maxTokens = 99
        val message = Message("TEST_PROMPT")
        val conversation = ConversationService.getInstance().startConversation()
        conversation.addMessage(Message("Ping", "Pong"))
        val requestHandler = CompletionRequestHandler(getRequestEventListener(message))
        expectLlama(StreamHttpExchange { request: RequestEntity ->
            assertThat(request.uri.path).isEqualTo("/completion")
            assertThat(request.body)
                .extracting(
                    "prompt",
                    "n_predict",
                    "stream")
                .containsExactly(
                    LLAMA.buildPrompt(
                        COMPLETION_SYSTEM_PROMPT,
                        "TEST_PROMPT",
                        conversation.messages),
                    99,
                    true)
            // Consistent with the other handlers: a plain (non-nullable) response list.
            listOf(
                jsonMapResponse("content", "Hel"),
                jsonMapResponse("content", "lo!"),
                jsonMapResponse(
                    e("content", ""),
                    e("stop", true)))
        })
        requestHandler.call(CallParameters(conversation, ConversationType.DEFAULT, message, false))
        waitExpecting { "Hello!" == message.response }
    }

    /** Returns a listener that copies the completed response text onto [message]. */
    private fun getRequestEventListener(message: Message): CompletionResponseEventListener {
        return object : CompletionResponseEventListener {
            override fun handleCompleted(fullMessage: String, callParameters: CallParameters) {
                message.response = fullMessage
            }
        }
    }
}

View file

@ -0,0 +1,149 @@
package ee.carlrobert.codegpt.completions
import ee.carlrobert.codegpt.completions.llama.PromptTemplate.ALPACA
import ee.carlrobert.codegpt.completions.llama.PromptTemplate.CHAT_ML
import ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA
import ee.carlrobert.codegpt.completions.llama.PromptTemplate.TORA
import ee.carlrobert.codegpt.conversations.message.Message
import org.assertj.core.api.Assertions.assertThat
import org.junit.Test
/**
 * Unit tests for the llama.cpp prompt templates: each template must render the
 * system prompt, the conversation history and the current user prompt in its
 * model-specific markup, both with and without prior history.
 */
class PromptTemplateTest {

    @Test
    fun shouldBuildLlamaPromptWithHistory() {
        val prompt = LLAMA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY)
        // Llama-2 style: <<SYS>> system block, [INST]/[/INST] around user turns.
        assertThat(prompt).isEqualTo("""
            <<SYS>>TEST_SYSTEM_PROMPT<</SYS>>
            [INST]TEST_PREV_PROMPT_1[/INST]
            TEST_PREV_RESPONSE_1
            [INST]TEST_PREV_PROMPT_2[/INST]
            TEST_PREV_RESPONSE_2
            [INST]TEST_USER_PROMPT[/INST]
            """.trimIndent())
    }

    @Test
    fun shouldBuildLlamaPromptWithoutHistory() {
        val prompt = LLAMA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, listOf())
        assertThat(prompt).isEqualTo("""
            <<SYS>>TEST_SYSTEM_PROMPT<</SYS>>
            [INST]TEST_USER_PROMPT[/INST]
            """.trimIndent())
    }

    @Test
    fun shouldBuildAlpacaPromptWithHistory() {
        val prompt = ALPACA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY)
        // Alpaca style: fixed preamble plus instruction/response sections.
        assertThat(prompt).isEqualTo("""
            Below is an instruction that describes a task. Write a response that appropriately completes the request.
            ### Instruction
            TEST_PREV_PROMPT_1
            ### Response:
            TEST_PREV_RESPONSE_1
            ### Instruction
            TEST_PREV_PROMPT_2
            ### Response:
            TEST_PREV_RESPONSE_2
            ### Instruction
            TEST_USER_PROMPT
            ### Response:
            """.trimIndent())
    }

    @Test
    fun shouldBuildAlpacaPromptWithoutHistory() {
        val prompt = ALPACA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, listOf())
        assertThat(prompt).isEqualTo("""
            Below is an instruction that describes a task. Write a response that appropriately completes the request.
            ### Instruction
            TEST_USER_PROMPT
            ### Response:
            """.trimIndent())
    }

    @Test
    fun shouldBuildChatMLPromptWithHistory() {
        val prompt = CHAT_ML.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY)
        // ChatML style: <|im_start|>role ... <|im_end|> blocks per turn.
        assertThat(prompt).isEqualTo("""
            <|im_start|>system
            TEST_SYSTEM_PROMPT<|im_end|>
            <|im_start|>user
            TEST_PREV_PROMPT_1<|im_end|>
            <|im_start|>assistant
            TEST_PREV_RESPONSE_1<|im_end|>
            <|im_start|>user
            TEST_PREV_PROMPT_2<|im_end|>
            <|im_start|>assistant
            TEST_PREV_RESPONSE_2<|im_end|>
            <|im_start|>user
            TEST_USER_PROMPT<|im_end|>
            """.trimIndent())
    }

    @Test
    fun shouldBuildChatMLPromptWithoutHistory() {
        val prompt = CHAT_ML.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, listOf())
        assertThat(prompt).isEqualTo("""
            <|im_start|>system
            TEST_SYSTEM_PROMPT<|im_end|>
            <|im_start|>user
            TEST_USER_PROMPT<|im_end|>
            """.trimIndent())
    }

    @Test
    fun shouldBuildToRAPromptWithHistory() {
        val prompt = TORA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY)
        // ToRA style: <|user|>/<|assistant|> markers; the system prompt is not rendered.
        assertThat(prompt).isEqualTo("""
            <|user|>
            TEST_PREV_PROMPT_1
            <|assistant|>
            TEST_PREV_RESPONSE_1
            <|user|>
            TEST_PREV_PROMPT_2
            <|assistant|>
            TEST_PREV_RESPONSE_2
            <|user|>
            TEST_USER_PROMPT
            <|assistant|>
            """.trimIndent())
    }

    @Test
    fun shouldBuildToRAPromptWithoutHistory() {
        val prompt = TORA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, listOf())
        assertThat(prompt).isEqualTo("""
            <|user|>
            TEST_USER_PROMPT
            <|assistant|>
            """.trimIndent())
    }

    companion object {
        private const val SYSTEM_PROMPT = "TEST_SYSTEM_PROMPT"
        private const val USER_PROMPT = "TEST_USER_PROMPT"
        // Two completed prompt/response rounds shared by the with-history tests.
        private val HISTORY: List<Message> = listOf(
            Message("TEST_PREV_PROMPT_1", "TEST_PREV_RESPONSE_1"),
            Message("TEST_PREV_PROMPT_2", "TEST_PREV_RESPONSE_2")
        )
    }
}

View file

@ -0,0 +1,89 @@
package ee.carlrobert.codegpt.conversations
import com.intellij.testFramework.fixtures.BasePlatformTestCase
import ee.carlrobert.codegpt.conversations.message.Message
import ee.carlrobert.codegpt.settings.GeneralSettings
import ee.carlrobert.codegpt.settings.service.ServiceType
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel
import org.assertj.core.api.Assertions.assertThat
/**
 * Tests for conversation persistence and navigation through
 * [ConversationService] and [ConversationsState].
 */
class ConversationsStateTest : BasePlatformTestCase() {

    fun testStartNewDefaultConversation() {
        GeneralSettings.getCurrentState().selectedService = ServiceType.OPENAI
        OpenAISettings.getCurrentState().model = OpenAIChatCompletionModel.GPT_3_5.code

        val conversation = ConversationService.getInstance().startConversation()

        // A freshly started conversation becomes the current one and reflects
        // the selected service's client code and model.
        assertThat(conversation).isEqualTo(ConversationsState.getCurrentConversation())
        assertThat(conversation)
            .extracting("clientCode", "model")
            .containsExactly("chat.completion", "gpt-3.5-turbo")
    }

    fun testSaveConversation() {
        val conversationService = ConversationService.getInstance()
        val conversation = conversationService.createConversation("chat.completion")
        conversationService.addConversation(conversation)
        conversation.addMessage(Message("TEST_PROMPT").apply { response = "TEST_RESPONSE" })

        conversationService.saveConversation(conversation)

        val currentConversation = ConversationsState.getCurrentConversation()
        assertThat(currentConversation).isNotNull()
        assertThat(currentConversation!!.messages)
            .flatExtracting("prompt", "response")
            .containsExactly("TEST_PROMPT", "TEST_RESPONSE")
    }

    fun testGetPreviousConversation() {
        val conversationService = ConversationService.getInstance()
        val initialConversation = conversationService.startConversation()
        conversationService.startConversation()

        val previous = conversationService.previousConversation

        assertThat(previous.isPresent).isTrue()
        assertThat(previous.get()).isEqualTo(initialConversation)
    }

    fun testGetNextConversation() {
        val conversationService = ConversationService.getInstance()
        val olderConversation = conversationService.startConversation()
        val newerConversation = conversationService.startConversation()
        ConversationsState.getInstance().setCurrentConversation(olderConversation)

        val next = conversationService.nextConversation

        assertThat(next.isPresent).isTrue()
        assertThat(next.get()).isEqualTo(newerConversation)
    }

    fun testDeleteSelectedConversation() {
        val conversationService = ConversationService.getInstance()
        val remainingConversation = conversationService.startConversation()
        conversationService.startConversation()

        conversationService.deleteSelectedConversation()

        // Deleting the selected conversation falls back to the remaining one.
        assertThat(ConversationsState.getCurrentConversation()).isEqualTo(remainingConversation)
        assertThat(conversationService.sortedConversations.size).isEqualTo(1)
        assertThat(conversationService.sortedConversations)
            .extracting("id")
            .containsExactly(remainingConversation.id)
    }

    fun testClearAllConversations() {
        val conversationService = ConversationService.getInstance()
        conversationService.startConversation()
        conversationService.startConversation()

        conversationService.clearAll()

        assertThat(ConversationsState.getCurrentConversation()).isNull()
        assertThat(conversationService.sortedConversations.size).isEqualTo(0)
    }
}

View file

@ -0,0 +1,79 @@
package ee.carlrobert.codegpt.settings.state
import com.intellij.testFramework.fixtures.BasePlatformTestCase
import ee.carlrobert.codegpt.completions.HuggingFaceModel
import ee.carlrobert.codegpt.conversations.Conversation
import ee.carlrobert.codegpt.settings.GeneralSettings
import ee.carlrobert.codegpt.settings.service.ServiceType
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings
import org.assertj.core.api.Assertions.assertThat
/**
 * Verifies that [GeneralSettings.sync] selects the matching service for a
 * conversation's client code and propagates the conversation's model into
 * the corresponding service settings.
 */
class GeneralSettingsTest : BasePlatformTestCase() {

    fun testOpenAISettingsSync() {
        // Syncing a "chat.completion" conversation selects OpenAI and adopts its model.
        val openAISettings = OpenAISettings.getCurrentState()
        openAISettings.model = "gpt-3.5-turbo"
        val generalSettings = GeneralSettings.getInstance()

        generalSettings.sync(conversationWith("gpt-4", "chat.completion"))

        assertThat(generalSettings.state.selectedService).isEqualTo(ServiceType.OPENAI)
        assertThat(openAISettings.model).isEqualTo("gpt-4")
    }

    fun testAzureSettingsSync() {
        val generalSettings = GeneralSettings.getInstance()

        generalSettings.sync(conversationWith("gpt-4", "azure.chat.completion"))

        assertThat(generalSettings.state.selectedService).isEqualTo(ServiceType.AZURE)
    }

    fun testYouSettingsSync() {
        val generalSettings = GeneralSettings.getInstance()

        generalSettings.sync(conversationWith("YouCode", "you.chat.completion"))

        assertThat(generalSettings.state.selectedService).isEqualTo(ServiceType.YOU)
    }

    fun testLlamaSettingsModelPathSync() {
        // A model name that is not a HuggingFaceModel constant is treated as a custom model path.
        val llamaSettings = LlamaSettings.getCurrentState()
        llamaSettings.huggingFaceModel = HuggingFaceModel.WIZARD_CODER_PYTHON_7B_Q3
        val generalSettings = GeneralSettings.getInstance()

        generalSettings.sync(conversationWith("TEST_LLAMA_MODEL_PATH", "llama.chat.completion"))

        assertThat(generalSettings.state.selectedService).isEqualTo(ServiceType.LLAMA_CPP)
        assertThat(llamaSettings.customLlamaModelPath).isEqualTo("TEST_LLAMA_MODEL_PATH")
        assertThat(llamaSettings.isUseCustomModel).isTrue()
    }

    fun testLlamaSettingsHuggingFaceModelSync() {
        // A model name matching a HuggingFaceModel constant switches custom-model mode off.
        val llamaSettings = LlamaSettings.getCurrentState()
        llamaSettings.huggingFaceModel = HuggingFaceModel.WIZARD_CODER_PYTHON_7B_Q3
        val generalSettings = GeneralSettings.getInstance()

        generalSettings.sync(conversationWith("CODE_LLAMA_7B_Q3", "llama.chat.completion"))

        assertThat(generalSettings.state.selectedService).isEqualTo(ServiceType.LLAMA_CPP)
        assertThat(llamaSettings.huggingFaceModel).isEqualTo(HuggingFaceModel.CODE_LLAMA_7B_Q3)
        assertThat(llamaSettings.isUseCustomModel).isFalse()
    }

    // Builds a conversation carrying the model and client code under test.
    private fun conversationWith(model: String, clientCode: String): Conversation =
        Conversation().apply {
            this.model = model
            this.clientCode = clientCode
        }
}

View file

@ -0,0 +1,410 @@
package ee.carlrobert.codegpt.toolwindow.chat
import ee.carlrobert.codegpt.CodeGPTKeys
import ee.carlrobert.codegpt.EncodingManager
import ee.carlrobert.codegpt.ReferencedFile
import ee.carlrobert.codegpt.completions.CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT
import ee.carlrobert.codegpt.completions.CompletionRequestProvider.FIX_COMPILE_ERRORS_SYSTEM_PROMPT
import ee.carlrobert.codegpt.completions.ConversationType
import ee.carlrobert.codegpt.completions.HuggingFaceModel
import ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA
import ee.carlrobert.codegpt.conversations.ConversationService
import ee.carlrobert.codegpt.conversations.message.Message
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings
import ee.carlrobert.llm.client.http.RequestEntity
import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange
import ee.carlrobert.llm.client.util.JSONUtil.e
import ee.carlrobert.llm.client.util.JSONUtil.jsonArray
import ee.carlrobert.llm.client.util.JSONUtil.jsonMap
import ee.carlrobert.llm.client.util.JSONUtil.jsonMapResponse
import org.apache.http.HttpHeaders
import org.assertj.core.api.Assertions.assertThat
import testsupport.IntegrationTest
import java.io.IOException
import java.nio.file.Files
import java.nio.file.Path
import java.util.Base64
import java.util.Objects
/**
 * Integration tests for [ChatToolWindowTabPanel]: sends chat messages through mocked
 * OpenAI and llama.cpp HTTP services, then verifies the outgoing request payload,
 * the assembled streamed response, token accounting, and persisted conversation state.
 */
class ChatToolWindowTabPanelTest : IntegrationTest() {
// Plain chat message against the mocked OpenAI streaming endpoint.
fun testSendingOpenAIMessage() {
useOpenAIService()
ConfigurationSettings.getCurrentState().systemPrompt = COMPLETION_SYSTEM_PROMPT
val message = Message("Hello!")
val conversation = ConversationService.getInstance().startConversation()
val panel = ChatToolWindowTabPanel(project, conversation)
expectOpenAI(StreamHttpExchange { request: RequestEntity ->
// Verify endpoint, auth header, and the exact chat-completions payload.
assertThat(request.uri.path).isEqualTo("/v1/chat/completions")
assertThat(request.method).isEqualTo("POST")
assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY")
assertThat(request.body)
.extracting(
"model",
"messages")
.containsExactly(
"gpt-4",
listOf(
mapOf("role" to "system", "content" to COMPLETION_SYSTEM_PROMPT),
mapOf("role" to "user", "content" to "Hello!")))
// Streamed delta chunks that assemble to the response "Hello!".
listOf(
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))))
})
panel.sendMessage(message)
// Block until the streamed response has been fully applied to the conversation.
waitExpecting {
val messages = conversation.messages
messages.isNotEmpty() && "Hello!" == messages[0].response
}
val encodingManager = EncodingManager.getInstance()
assertThat(panel.tokenDetails).extracting(
"systemPromptTokens",
"conversationTokens",
"userPromptTokens",
"highlightedTokens")
.containsExactly(
encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT),
encodingManager.countTokens(message.prompt),
0,
0)
assertThat(panel.conversation)
.isNotNull()
.extracting("id", "model", "clientCode", "discardTokenLimit")
.containsExactly(
conversation.id,
conversation.model,
conversation.clientCode,
false)
val messages = panel.conversation.messages
assertThat(messages).hasSize(1)
assertThat(messages[0])
.extracting("id", "prompt", "response")
.containsExactly(message.id, message.prompt, message.response)
}
// Message with referenced files: their paths/contents are inlined into the user prompt.
fun testSendingOpenAIMessageWithReferencedContext() {
project.putUserData(CodeGPTKeys.SELECTED_FILES, listOf(
ReferencedFile("TEST_FILE_NAME_1", "TEST_FILE_PATH_1", "TEST_FILE_CONTENT_1"),
ReferencedFile("TEST_FILE_NAME_2", "TEST_FILE_PATH_2", "TEST_FILE_CONTENT_2"),
ReferencedFile("TEST_FILE_NAME_3", "TEST_FILE_PATH_3", "TEST_FILE_CONTENT_3")))
useOpenAIService()
ConfigurationSettings.getCurrentState().systemPrompt = COMPLETION_SYSTEM_PROMPT
val message = Message("TEST_MESSAGE")
message.userMessage = "TEST_MESSAGE"
message.referencedFilePaths = listOf("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3")
val conversation = ConversationService.getInstance().startConversation()
val panel = ChatToolWindowTabPanel(project, conversation)
expectOpenAI(StreamHttpExchange { request: RequestEntity ->
assertThat(request.uri.path).isEqualTo("/v1/chat/completions")
assertThat(request.method).isEqualTo("POST")
assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY")
// The user message must carry each referenced file rendered as a fenced block.
assertThat(request.body)
.extracting(
"model",
"messages")
.containsExactly(
"gpt-4",
listOf(
mapOf("role" to "system", "content" to COMPLETION_SYSTEM_PROMPT),
mapOf("role" to "user", "content" to """
Use the following context to answer question at the end:
File Path: TEST_FILE_PATH_1
File Content:
```TEST_FILE_NAME_1
TEST_FILE_CONTENT_1
```
File Path: TEST_FILE_PATH_2
File Content:
```TEST_FILE_NAME_2
TEST_FILE_CONTENT_2
```
File Path: TEST_FILE_PATH_3
File Content:
```TEST_FILE_NAME_3
TEST_FILE_CONTENT_3
```
Question: TEST_MESSAGE""".trimIndent())))
listOf(
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))))
})
panel.sendMessage(message)
waitExpecting {
val messages = conversation.messages
messages.isNotEmpty() && "Hello!" == messages[0].response
}
val encodingManager = EncodingManager.getInstance()
assertThat(panel.tokenDetails).extracting(
"systemPromptTokens",
"conversationTokens",
"userPromptTokens",
"highlightedTokens")
.containsExactly(
encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT),
encodingManager.countTokens(message.prompt),
0,
0)
assertThat(panel.conversation)
.isNotNull()
.extracting("id", "model", "clientCode", "discardTokenLimit")
.containsExactly(
conversation.id,
conversation.model,
conversation.clientCode,
false)
val messages = panel.conversation.messages
assertThat(messages).hasSize(1)
// The referenced paths must also be persisted on the stored message.
assertThat(messages[0])
.extracting("id", "prompt", "response", "referencedFilePaths")
.containsExactly(
message.id,
message.prompt,
message.response,
listOf("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3"))
}
// Image attachment: the request must embed the file as a base64 data URL for the vision model.
fun testSendingOpenAIMessageWithImage() {
val testImagePath = Objects.requireNonNull(javaClass.getResource("/images/test-image.png")).path
project.putUserData(CodeGPTKeys.IMAGE_ATTACHMENT_FILE_PATH, testImagePath)
useOpenAIService("gpt-4-vision-preview")
ConfigurationSettings.getCurrentState().systemPrompt = COMPLETION_SYSTEM_PROMPT
val message = Message("TEST_MESSAGE")
val conversation = ConversationService.getInstance().startConversation()
val panel = ChatToolWindowTabPanel(project, conversation)
expectOpenAI(StreamHttpExchange { request: RequestEntity ->
assertThat(request.uri.path).isEqualTo("/v1/chat/completions")
assertThat(request.method).isEqualTo("POST")
assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY")
try {
// Re-encode the fixture image independently to verify the data URL sent.
val testImageUrl = ("data:image/png;base64,"
+ Base64.getEncoder().encodeToString(Files.readAllBytes(Path.of(testImagePath))))
assertThat(request.body)
.extracting("model", "messages")
.containsExactly(
"gpt-4-vision-preview",
listOf(
mapOf("role" to "system", "content" to COMPLETION_SYSTEM_PROMPT),
mapOf("role" to "user", "content" to listOf(
mapOf(
"type" to "image_url",
"image_url" to mapOf("url" to testImageUrl)),
mapOf("type" to "text", "text" to "TEST_MESSAGE")
))))
} catch (e: IOException) {
throw RuntimeException(e)
}
listOf<String?>(
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))))
})
panel.sendMessage(message)
waitExpecting {
val messages = conversation.messages
messages.isNotEmpty() && "Hello!" == messages[0].response
}
val encodingManager = EncodingManager.getInstance()
assertThat(panel.tokenDetails).extracting(
"systemPromptTokens",
"conversationTokens",
"userPromptTokens",
"highlightedTokens")
.containsExactly(
encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT),
encodingManager.countTokens(message.prompt),
0,
0)
assertThat(panel.conversation)
.isNotNull()
.extracting("id", "model", "clientCode", "discardTokenLimit")
.containsExactly(
conversation.id,
conversation.model,
conversation.clientCode,
false)
val messages = panel.conversation.messages
assertThat(messages).hasSize(1)
assertThat(messages[0])
.extracting("id", "prompt", "response", "imageFilePath")
.containsExactly(
message.id,
message.prompt,
message.response,
message.imageFilePath)
}
// FIX_COMPILE_ERRORS conversation type: the request must use the fix-compile-errors system prompt.
fun testFixCompileErrorsWithOpenAIService() {
project.putUserData(
CodeGPTKeys.SELECTED_FILES, listOf(
ReferencedFile("TEST_FILE_NAME_1", "TEST_FILE_PATH_1", "TEST_FILE_CONTENT_1"),
ReferencedFile("TEST_FILE_NAME_2", "TEST_FILE_PATH_2", "TEST_FILE_CONTENT_2"),
ReferencedFile("TEST_FILE_NAME_3", "TEST_FILE_PATH_3", "TEST_FILE_CONTENT_3")))
useOpenAIService()
ConfigurationSettings.getCurrentState().systemPrompt = COMPLETION_SYSTEM_PROMPT
val message = Message("TEST_MESSAGE")
message.userMessage = "TEST_MESSAGE"
message.referencedFilePaths = listOf("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3")
val conversation = ConversationService.getInstance().startConversation()
val panel = ChatToolWindowTabPanel(project, conversation)
expectOpenAI(StreamHttpExchange { request: RequestEntity ->
assertThat(request.uri.path).isEqualTo("/v1/chat/completions")
assertThat(request.method).isEqualTo("POST")
assertThat(request.headers[HttpHeaders.AUTHORIZATION]!![0]).isEqualTo("Bearer TEST_API_KEY")
assertThat(request.body)
.extracting(
"model",
"messages")
.containsExactly(
"gpt-4",
listOf(
mapOf("role" to "system", "content" to FIX_COMPILE_ERRORS_SYSTEM_PROMPT),
mapOf("role" to "user", "content" to """
Use the following context to answer question at the end:
File Path: TEST_FILE_PATH_1
File Content:
```TEST_FILE_NAME_1
TEST_FILE_CONTENT_1
```
File Path: TEST_FILE_PATH_2
File Content:
```TEST_FILE_NAME_2
TEST_FILE_CONTENT_2
```
File Path: TEST_FILE_PATH_3
File Content:
```TEST_FILE_NAME_3
TEST_FILE_CONTENT_3
```
Question: TEST_MESSAGE""".trimIndent())))
listOf(
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("role", "assistant")))),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "Hel")))),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "lo")))),
jsonMapResponse("choices", jsonArray(jsonMap("delta", jsonMap("content", "!")))))
})
panel.sendMessage(message, ConversationType.FIX_COMPILE_ERRORS)
waitExpecting {
val messages = conversation.messages
messages.isNotEmpty() && "Hello!" == messages[0].response
}
val encodingManager = EncodingManager.getInstance()
// NOTE(review): token details are asserted against COMPLETION_SYSTEM_PROMPT even though the
// request above uses FIX_COMPILE_ERRORS_SYSTEM_PROMPT — confirm this asymmetry is intended.
assertThat(panel.tokenDetails).extracting(
"systemPromptTokens",
"conversationTokens",
"userPromptTokens",
"highlightedTokens")
.containsExactly(
encodingManager.countTokens(COMPLETION_SYSTEM_PROMPT),
encodingManager.countTokens(message.prompt),
0,
0)
assertThat(panel.conversation)
.isNotNull()
.extracting("id", "model", "clientCode", "discardTokenLimit")
.containsExactly(
conversation.id,
conversation.model,
conversation.clientCode,
false)
val messages = panel.conversation.messages
assertThat(messages).hasSize(1)
assertThat(messages[0])
.extracting("id", "prompt", "response", "referencedFilePaths")
.containsExactly(
message.id,
message.prompt,
message.response,
listOf("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3"))
}
// llama.cpp service: sampling parameters from LlamaSettings must reach the /completion request.
fun testSendingLlamaMessage() {
useLlamaService()
val configurationState = ConfigurationSettings.getCurrentState()
configurationState.systemPrompt = COMPLETION_SYSTEM_PROMPT
configurationState.maxTokens = 1000
configurationState.temperature = 0.1
val llamaSettings = LlamaSettings.getCurrentState()
llamaSettings.isUseCustomModel = false
llamaSettings.huggingFaceModel = HuggingFaceModel.CODE_LLAMA_7B_Q4
llamaSettings.topK = 30
llamaSettings.topP = 0.8
llamaSettings.minP = 0.03
llamaSettings.repeatPenalty = 1.3
val message = Message("TEST_PROMPT")
val conversation = ConversationService.getInstance().startConversation()
val panel = ChatToolWindowTabPanel(project, conversation)
expectLlama(StreamHttpExchange { request: RequestEntity ->
assertThat(request.uri.path).isEqualTo("/completion")
assertThat(request.body)
.extracting(
"prompt",
"n_predict",
"stream",
"temperature",
"top_k",
"top_p",
"min_p",
"repeat_penalty")
.containsExactly(
LLAMA.buildPrompt(
COMPLETION_SYSTEM_PROMPT,
"TEST_PROMPT",
conversation.messages),
configurationState.maxTokens,
true,
configurationState.temperature,
llamaSettings.topK,
llamaSettings.topP,
llamaSettings.minP,
llamaSettings.repeatPenalty)
// Streamed chunks ending with a stop marker, assembling to "Hello!".
listOf<String?>(
jsonMapResponse("content", "Hel"),
jsonMapResponse("content", "lo!"),
jsonMapResponse(
e("content", ""),
e("stop", true)))
})
panel.sendMessage(message, ConversationType.DEFAULT)
waitExpecting {
val messages = conversation.messages
messages.isNotEmpty() && "Hello!" == messages[0].response
}
assertThat(panel.conversation)
.isNotNull()
.extracting("id", "model", "clientCode", "discardTokenLimit")
.containsExactly(
conversation.id,
conversation.model,
conversation.clientCode,
false)
val messages = panel.conversation.messages
assertThat(messages).hasSize(1)
assertThat(messages[0])
.extracting("id", "prompt", "response")
.containsExactly(message.id, message.prompt, message.response)
}
}

View file

@ -0,0 +1,50 @@
package ee.carlrobert.codegpt.toolwindow.chat
import com.intellij.openapi.util.Disposer
import com.intellij.testFramework.fixtures.BasePlatformTestCase
import ee.carlrobert.codegpt.conversations.ConversationService
import ee.carlrobert.codegpt.conversations.message.Message
import org.assertj.core.api.Assertions.assertThat
/**
 * Tests tab lifecycle behavior of [ChatToolWindowTabbedPane]:
 * adding, clearing, and resetting chat tabs.
 */
class ChatToolWindowTabbedPaneTest : BasePlatformTestCase() {

    fun testClearAllTabs() {
        // A cleared pane must not retain any tab mappings.
        val pane = ChatToolWindowTabbedPane(Disposer.newDisposable())
        pane.addNewTab(createNewTabPanel())

        pane.clearAll()

        assertThat(pane.activeTabMapping).isEmpty()
    }

    fun testAddingNewTabs() {
        // Tabs are titled sequentially: "Chat 1", "Chat 2", ...
        val pane = ChatToolWindowTabbedPane(Disposer.newDisposable())
        repeat(3) { pane.addNewTab(createNewTabPanel()) }

        assertThat(pane.activeTabMapping.keys)
            .containsExactly("Chat 1", "Chat 2", "Chat 3")
    }

    fun testResetCurrentlyActiveTabPanel() {
        // Resetting the active tab discards its conversation's messages.
        val pane = ChatToolWindowTabbedPane(Disposer.newDisposable())
        val conversation = ConversationService.getInstance().startConversation().apply {
            addMessage(Message("TEST_PROMPT", "TEST_RESPONSE"))
        }
        pane.addNewTab(ChatToolWindowTabPanel(project, conversation))

        pane.resetCurrentlyActiveTabPanel(project)

        val activePanel = pane.activeTabMapping["Chat 1"]
        assertThat(activePanel!!.conversation.messages).isEmpty()
    }

    // Builds a tab panel backed by a brand-new conversation.
    private fun createNewTabPanel(): ChatToolWindowTabPanel =
        ChatToolWindowTabPanel(project, ConversationService.getInstance().startConversation())
}

View file

@ -1,14 +1,13 @@
package ee.carlrobert.codegpt.util;
package ee.carlrobert.codegpt.util
import static org.assertj.core.api.Assertions.assertThat;
import org.assertj.core.api.Assertions.assertThat
import org.junit.Test
import org.junit.Test;
public class MarkdownUtilTest {
class MarkdownUtilTest {
@Test
public void shouldExtractMarkdownCodeBlocks() {
String testInput = """
fun shouldExtractMarkdownCodeBlocks() {
val testInput = """
**C++ Code Block**
```cpp
#include <iostream>
@ -29,60 +28,68 @@ public class MarkdownUtilTest {
```
1. We define a **public class** called **Main**.
2. We define the **main** method which is the entry point of the program.
""";
var result = MarkdownUtil.splitCodeBlocks(testInput);
""".trimIndent()
val result = MarkdownUtil.splitCodeBlocks(testInput)
assertThat(result).containsExactly("""
**C++ Code Block**
""", """
""".trimIndent(), """
```cpp
#include <iostream>
int main() {
return 0;
}
```""", """
```""".trimIndent(), """
1. We include the **iostream** header file.
2. We define the main function.
**Java Code Block**
""", """
""".trimIndent(), """
```java
public class Main {
public static void main(String[] args) {
}
}
```""", """
```""".trimIndent(), """
1. We define a **public class** called **Main**.
2. We define the **main** method which is the entry point of the program.
""");
""".trimIndent())
}
@Test
public void shouldExtractMarkdownWithoutCode() {
String testInput = """
fun shouldExtractMarkdownWithoutCode() {
val testInput = """
**C++ Code Block**
1. We include the **iostream** header file.
2. We define the main function.
""";
var result = MarkdownUtil.splitCodeBlocks(testInput);
""".trimIndent()
val result = MarkdownUtil.splitCodeBlocks(testInput)
assertThat(result).containsExactly("""
**C++ Code Block**
1. We include the **iostream** header file.
2. We define the main function.
""");
""".trimIndent())
}
@Test
public void shouldExtractMarkdownCodeOnly() {
String testInput = """
fun shouldExtractMarkdownCodeOnly() {
val testInput = """
```cpp
#include <iostream>
@ -96,9 +103,10 @@ public class MarkdownUtilTest {
}
}
```
""";
var result = MarkdownUtil.splitCodeBlocks(testInput);
""".trimIndent()
val result = MarkdownUtil.splitCodeBlocks(testInput)
assertThat(result).containsExactly("""
```cpp
@ -107,12 +115,12 @@ public class MarkdownUtilTest {
int main() {
return 0;
}
```""", """
```""".trimIndent(), """
```java
public class Main {
public static void main(String[] args) {
}
}
```""");
```""".trimIndent())
}
}

View file

@ -0,0 +1,33 @@
package testsupport
import com.intellij.openapi.util.Key
import com.intellij.testFramework.fixtures.BasePlatformTestCase
import ee.carlrobert.codegpt.CodeGPTKeys
import ee.carlrobert.llm.client.mixin.ExternalServiceTestMixin
import testsupport.mixin.ShortcutsTestMixin
/**
 * Base class for integration tests that talk to the mocked external services.
 * Guarantees that service stubs and project-level CodeGPT keys are reset
 * between tests.
 */
open class IntegrationTest : BasePlatformTestCase(), ExternalServiceTestMixin, ShortcutsTestMixin {

    @Throws(Exception::class)
    override fun tearDown() {
        ExternalServiceTestMixin.clearAll()
        resetProjectKeys()
        super.tearDown()
    }

    // Restores every CodeGPT user-data key to an empty default.
    private fun resetProjectKeys() {
        setKey(CodeGPTKeys.SELECTED_FILES, emptyList())
        setKey(CodeGPTKeys.PREVIOUS_INLAY_TEXT, "")
        setKey(CodeGPTKeys.IMAGE_ATTACHMENT_FILE_PATH, "")
    }

    private fun <T> setKey(key: Key<T>, value: T) {
        project.putUserData(key, value)
    }

    companion object {
        init {
            // Bring up the shared mock external services once per test run.
            ExternalServiceTestMixin.init()
        }
    }
}

View file

@ -0,0 +1,52 @@
package testsupport.mixin
import com.intellij.testFramework.PlatformTestUtil
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.AZURE_OPENAI_API_KEY
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.OPENAI_API_KEY
import ee.carlrobert.codegpt.credentials.CredentialsStore.setCredential
import ee.carlrobert.codegpt.settings.GeneralSettings
import ee.carlrobert.codegpt.settings.service.ServiceType
import ee.carlrobert.codegpt.settings.service.azure.AzureSettings
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings
import java.util.function.BooleanSupplier
/**
 * Test shortcuts for switching the plugin to a given LLM service with
 * canned credentials and settings, plus a helper for awaiting async results.
 */
interface ShortcutsTestMixin {

    /**
     * Selects the OpenAI service with a test API key and the given [model].
     *
     * Note: the former no-arg overload was folded into the default argument —
     * Kotlin call sites `useOpenAIService()` still resolve here.
     */
    fun useOpenAIService(model: String? = "gpt-4") {
        GeneralSettings.getCurrentState().selectedService = ServiceType.OPENAI
        setCredential(OPENAI_API_KEY, "TEST_API_KEY")
        OpenAISettings.getCurrentState().model = model
    }

    /** Selects the Azure OpenAI service with test credentials and deployment settings. */
    fun useAzureService() {
        GeneralSettings.getCurrentState().selectedService = ServiceType.AZURE
        setCredential(AZURE_OPENAI_API_KEY, "TEST_API_KEY")
        val azureSettings = AzureSettings.getCurrentState()
        azureSettings.resourceName = "TEST_RESOURCE_NAME"
        azureSettings.apiVersion = "TEST_API_VERSION"
        azureSettings.deploymentId = "TEST_DEPLOYMENT_ID"
    }

    /** Selects the You.com service (no credentials required). */
    fun useYouService() {
        GeneralSettings.getCurrentState().selectedService = ServiceType.YOU
    }

    /** Selects the llama.cpp service; the null port lets the mock server supply one. */
    fun useLlamaService() {
        GeneralSettings.getCurrentState().selectedService = ServiceType.LLAMA_CPP
        LlamaSettings.getCurrentState().serverPort = null
    }

    /**
     * Dispatches UI events until [condition] holds or the 5-second timeout elapses.
     *
     * @throws IllegalArgumentException if [condition] is null (previously a bare NPE via `!!`)
     */
    fun waitExpecting(condition: BooleanSupplier?) {
        PlatformTestUtil.waitWithEventsDispatching(
            "Waiting for message response timed out or did not meet expected conditions",
            requireNotNull(condition) { "condition must not be null" },
            5
        )
    }
}