feat: extract llama request settings to its own state, improve UI/UX

This commit is contained in:
Carl-Robert Linnupuu 2023-12-21 14:46:45 +02:00
parent 9d83107dd5
commit e230640063
15 changed files with 446 additions and 267 deletions

View file

@@ -1,6 +1,8 @@
package ee.carlrobert.codegpt.toolwindow.chat;
import static ee.carlrobert.codegpt.completions.CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT;
import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA;
import static ee.carlrobert.llm.client.util.JSONUtil.e;
import static ee.carlrobert.llm.client.util.JSONUtil.jsonArray;
import static ee.carlrobert.llm.client.util.JSONUtil.jsonMap;
import static ee.carlrobert.llm.client.util.JSONUtil.jsonMapResponse;
@@ -11,13 +13,14 @@ import static org.awaitility.Awaitility.await;
import ee.carlrobert.codegpt.CodeGPTKeys;
import ee.carlrobert.codegpt.EncodingManager;
import ee.carlrobert.codegpt.completions.HuggingFaceModel;
import ee.carlrobert.codegpt.conversations.ConversationService;
import ee.carlrobert.codegpt.conversations.message.Message;
import ee.carlrobert.codegpt.settings.configuration.ConfigurationState;
import ee.carlrobert.codegpt.settings.state.LlamaSettingsState;
import ee.carlrobert.codegpt.toolwindow.chat.standard.StandardChatToolWindowTabPanel;
import ee.carlrobert.embedding.CheckedFile;
import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import testsupport.IntegrationTest;
@@ -166,4 +169,74 @@ public class StandardChatToolWindowTabPanelTest extends IntegrationTest {
message.getResponse(),
List.of("TEST_FILE_PATH_1", "TEST_FILE_PATH_2", "TEST_FILE_PATH_3"));
}
/**
 * Verifies that sending a chat message through the llama service builds the expected
 * {@code /completion} request from the configured sampling parameters, and that the
 * streamed response chunks are assembled into the conversation shown by the panel.
 */
public void testSendingLlamaMessage() {
  useLlamaService();
  var configuration = ConfigurationState.getInstance();
  configuration.setSystemPrompt(COMPLETION_SYSTEM_PROMPT);
  configuration.setMaxTokens(1000);
  configuration.setTemperature(0.1);
  var settings = LlamaSettingsState.getInstance();
  settings.setUseCustomModel(false);
  settings.setHuggingFaceModel(HuggingFaceModel.CODE_LLAMA_7B_Q4);
  settings.setTopK(30);
  settings.setTopP(0.8);
  settings.setMinP(0.03);
  settings.setRepeatPenalty(1.3);
  var conversation = ConversationService.getInstance().startConversation();
  var tabPanel = new StandardChatToolWindowTabPanel(getProject(), conversation);
  var userMessage = new Message("TEST_PROMPT");
  expectLlama((StreamHttpExchange) request -> {
    assertThat(request.getUri().getPath()).isEqualTo("/completion");
    // The request body must mirror every configured sampling parameter.
    assertThat(request.getBody())
        .extracting(
            "prompt",
            "n_predict",
            "stream",
            "temperature",
            "top_k",
            "top_p",
            "min_p",
            "repeat_penalty")
        .containsExactly(
            LLAMA.buildPrompt(
                COMPLETION_SYSTEM_PROMPT,
                "TEST_PROMPT",
                conversation.getMessages()),
            configuration.getMaxTokens(),
            true,
            configuration.getTemperature(),
            settings.getTopK(),
            settings.getTopP(),
            settings.getMinP(),
            settings.getRepeatPenalty());
    // Stream the reply in three chunks; the final one carries the stop flag.
    return List.of(
        jsonMapResponse("content", "Hel"),
        jsonMapResponse("content", "lo!"),
        jsonMapResponse(
            e("content", ""),
            e("stop", true)));
  });

  tabPanel.sendMessage(userMessage);

  // Wait until the streamed chunks have been concatenated into the first response.
  await().atMost(5, SECONDS).until(() -> {
    var conversationMessages = conversation.getMessages();
    return !conversationMessages.isEmpty()
        && "Hello!".equals(conversationMessages.get(0).getResponse());
  });
  assertThat(tabPanel.getConversation())
      .isNotNull()
      .extracting("id", "model", "clientCode", "discardTokenLimit")
      .containsExactly(
          conversation.getId(),
          conversation.getModel(),
          conversation.getClientCode(),
          false);
  var storedMessages = tabPanel.getConversation().getMessages();
  assertThat(storedMessages).hasSize(1);
  assertThat(storedMessages.get(0))
      .extracting("id", "prompt", "response")
      .containsExactly(
          userMessage.getId(),
          userMessage.getPrompt(),
          userMessage.getResponse());
}
}