#178 - Add support for running local LLMs via LLaMA C/C++ port (#249)

* Initial implementation of integrating llama.cpp to run LLaMA models locally

* Move submodule

* Copy llama submodule to bundle

* Support for downloading models from IDE

* Code cleanup

* Store port field

* Replace service selection radio group with dropdown

* Add quantization support + other fixes

* Add option to override host

* Fix override host handler

* Disable port field when override host enabled

* Design updates

* Fix llama settings configuration, design changes, clean up code

* Improve You.com coupon design

* Add new Phind model and help tooltip

* Fetch you.com subscription

* Add CodeBooga model, fix downloadable model selection

* Chat history support

* Code refactoring, minor bug fixes

* UI updates, several bug fixes, removed code llama python model

* Code cleanup, enable llama port only on macOS

* Change downloaded gguf models path

* Move some of the labels to codegpt bundle

* Minor fixes

* Remove ToRA model, add help texts

* Fix test

* Modify description
This commit is contained in:
Carl-Robert 2023-11-03 12:00:24 +02:00 committed by GitHub
parent ca2eb9b6fa
commit 45908e69df
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
71 changed files with 2748 additions and 533 deletions

View file

@ -1,6 +1,8 @@
package ee.carlrobert.codegpt.completions;
import static ee.carlrobert.codegpt.completions.CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT;
import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA;
import static ee.carlrobert.llm.client.util.JSONUtil.e;
import static ee.carlrobert.llm.client.util.JSONUtil.jsonArray;
import static ee.carlrobert.llm.client.util.JSONUtil.jsonMap;
import static ee.carlrobert.llm.client.util.JSONUtil.jsonMapResponse;
@ -10,14 +12,17 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.awaitility.Awaitility.await;
import com.intellij.testFramework.fixtures.BasePlatformTestCase;
import ee.carlrobert.codegpt.CodeGPTPlugin;
import ee.carlrobert.codegpt.conversations.ConversationService;
import ee.carlrobert.codegpt.conversations.message.Message;
import ee.carlrobert.codegpt.credentials.AzureCredentialsManager;
import ee.carlrobert.codegpt.credentials.OpenAICredentialsManager;
import ee.carlrobert.codegpt.settings.configuration.ConfigurationState;
import ee.carlrobert.codegpt.settings.state.AzureSettingsState;
import ee.carlrobert.codegpt.settings.state.LlamaSettingsState;
import ee.carlrobert.codegpt.settings.state.OpenAISettingsState;
import ee.carlrobert.codegpt.settings.state.SettingsState;
import ee.carlrobert.codegpt.settings.state.YouSettingsState;
import ee.carlrobert.llm.client.http.LocalCallbackServer;
import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange;
import ee.carlrobert.llm.client.http.expectation.StreamExpectation;
@ -34,8 +39,11 @@ public class DefaultCompletionRequestHandlerTest extends BasePlatformTestCase {
super.setUp();
AzureCredentialsManager.getInstance().setApiKey("TEST_API_KEY");
OpenAICredentialsManager.getInstance().setApiKey("TEST_API_KEY");
// FIXME
OpenAISettingsState.getInstance().setBaseHost("http://127.0.0.1:8000");
AzureSettingsState.getInstance().setBaseHost("http://127.0.0.1:8000");
YouSettingsState.getInstance().setBaseHost("http://127.0.0.1:8000");
LlamaSettingsState.getInstance().setServerPort(8000);
ConfigurationState.getInstance().setSystemPrompt("");
server = new LocalCallbackServer(8000);
}
@ -46,7 +54,7 @@ public class DefaultCompletionRequestHandlerTest extends BasePlatformTestCase {
super.tearDown();
}
public void testChatCompletionCall() {
public void testOpenAIChatCompletionCall() {
var message = new Message("TEST_PROMPT");
var conversation = ConversationService.getInstance().startConversation();
var requestHandler = new CompletionRequestHandler();
@ -54,6 +62,8 @@ public class DefaultCompletionRequestHandlerTest extends BasePlatformTestCase {
var settings = SettingsState.getInstance();
settings.setUseOpenAIService(true);
settings.setUseAzureService(false);
settings.setUseYouService(false);
settings.setUseLlamaService(false);
expectStreamRequest("/v1/chat/completions", request -> {
assertThat(request.getMethod()).isEqualTo("POST");
assertThat(request.getHeaders().get(AUTHORIZATION).get(0)).isEqualTo("Bearer TEST_API_KEY");
@ -84,6 +94,8 @@ public class DefaultCompletionRequestHandlerTest extends BasePlatformTestCase {
var settings = SettingsState.getInstance();
settings.setUseOpenAIService(false);
settings.setUseAzureService(true);
settings.setUseYouService(false);
settings.setUseLlamaService(false);
var azureSettings = AzureSettingsState.getInstance();
azureSettings.setResourceName("TEST_RESOURCE_NAME");
azureSettings.setApiVersion("TEST_API_VERSION");
@ -123,6 +135,97 @@ public class DefaultCompletionRequestHandlerTest extends BasePlatformTestCase {
await().atMost(5, SECONDS).until(() -> "Hello!".equals(message.getResponse()));
}
// Verifies that selecting the You.com service routes completions through the
// /api/streamingSearch GET endpoint with the expected query parameters
// (including serialized chat history) and session cookies, and that streamed
// "youChatToken" chunks are concatenated into the message response.
public void testYouChatCompletionCall() {
var message = new Message("TEST_PROMPT");
var conversation = ConversationService.getInstance().startConversation();
// Pre-existing exchange; must appear serialized in the "chat" query parameter below.
conversation.addMessage(new Message("Ping", "Pong"));
var requestHandler = new CompletionRequestHandler();
requestHandler.addRequestCompletedListener(message::setResponse);
// Enable only the You.com service; the other three providers are explicitly
// disabled so the handler's service-selection logic is unambiguous.
var settings = SettingsState.getInstance();
settings.setUseOpenAIService(false);
settings.setUseAzureService(false);
settings.setUseYouService(true);
settings.setUseLlamaService(false);
expectStreamRequest("/api/streamingSearch", request -> {
assertThat(request.getMethod()).isEqualTo("GET");
assertThat(request.getUri().getPath()).isEqualTo("/api/streamingSearch");
// Exact query string, including UTM tracking parameters and the plugin
// version embedded in utm_campaign.
assertThat(request.getUri().getQuery()).isEqualTo(
"q=TEST_PROMPT&" +
"page=1&" +
"cfr=CodeGPT&" +
"count=10&" +
"safeSearch=WebPages,Translations,TimeZone,Computation,RelatedSearches&" +
"domain=youchat&" +
"chat=[{\"question\":\"Ping\",\"answer\":\"Pong\"}]&" +
"utm_source=ide&" +
"utm_medium=jetbrains&" +
"utm_campaign=" + CodeGPTPlugin.getVersion() + "&" +
"utm_content=CodeGPT");
// Header expectations: SSE accept header plus the hard-coded session/cookie
// fixture. NOTE(review): the __cf_bm cookie value looks like a captured
// Cloudflare token — presumably a frozen test fixture; confirm it is not
// expected to rotate.
assertThat(request.getHeaders())
.flatExtracting("Host", "Accept", "Connection", "User-agent", "Cookie")
.containsExactly("127.0.0.1:8000",
"text/event-stream",
"Keep-Alive",
"youide CodeGPT",
"safesearch_guest=Moderate; " +
"youpro_subscription=true; " +
"you_subscription=free; " +
"stytch_session=; " +
"ydc_stytch_session=; " +
"stytch_session_jwt=; " +
"ydc_stytch_session_jwt=; " +
"eg4=false; " +
"safesearch_9015f218b47611b62bbbaf61125cd2dac629e65c3d6f47573a2ec0e9b615c691=Moderate; "
+
"__cf_bm=aN2b3pQMH8XADeMB7bg9s1bJ_bfXBcCHophfOGRg6g0-1693601599-0-AWIt5Mr4Y3xQI4mIJ1lSf4+vijWKDobrty8OopDeBxY+NABe0MRFidF3dCUoWjRt8SVMvBZPI3zkOgcRs7Mz3yazd7f7c58HwW5Xg9jdBjNg;");
// Stream three token chunks that should concatenate to "Hello!".
return List.of(
jsonMapResponse("youChatToken", "Hel"),
jsonMapResponse("youChatToken", "lo"),
jsonMapResponse("youChatToken", "!"));
});
requestHandler.call(conversation, message, false);
await().atMost(5, SECONDS).until(() -> "Hello!".equals(message.getResponse()));
}
// Verifies that selecting the llama.cpp service posts to the local server's
// /completion endpoint with a LLAMA-template prompt built from the system
// prompt, the conversation history, and the new user prompt, and that streamed
// "content" chunks (terminated by stop=true) are joined into the response.
public void testLlamaChatCompletionCall() {
var userMessage = new Message("TEST_PROMPT");
var conversation = ConversationService.getInstance().startConversation();
// Seed one prior exchange so the prompt template has history to fold in.
conversation.addMessage(new Message("Ping", "Pong"));
// Enable only the llama service; all other providers are switched off.
var settings = SettingsState.getInstance();
settings.setUseOpenAIService(false);
settings.setUseAzureService(false);
settings.setUseYouService(false);
settings.setUseLlamaService(true);
var requestHandler = new CompletionRequestHandler();
requestHandler.addRequestCompletedListener(userMessage::setResponse);
expectStreamRequest("/completion", request -> {
// The request body must carry the fully rendered prompt plus the
// generation parameters the plugin hard-codes for llama.cpp.
var expectedPrompt = LLAMA.buildPrompt(
COMPLETION_SYSTEM_PROMPT, "TEST_PROMPT", conversation.getMessages());
assertThat(request.getBody())
.extracting("prompt", "n_predict", "stream")
.containsExactly(expectedPrompt, 512, true);
// Emit two content chunks followed by an empty stop chunk.
return List.of(
jsonMapResponse("content", "Hel"),
jsonMapResponse("content", "lo!"),
jsonMapResponse(e("content", ""), e("stop", true)));
});
requestHandler.call(conversation, userMessage, false);
await().atMost(5, SECONDS).until(() -> "Hello!".equals(userMessage.getResponse()));
}
// Registers a streaming expectation on the local callback server for the
// given path, delegating request handling to the supplied exchange.
private void expectStreamRequest(String path, StreamHttpExchange exchange) {
var expectation = new StreamExpectation(path, exchange);
server.addExpectation(expectation);
}

View file

@ -0,0 +1,141 @@
package ee.carlrobert.codegpt.completions;
import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.ALPACA;
import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.CHAT_ML;
import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA;
import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.TORA;
import static org.assertj.core.api.Assertions.assertThat;
import ee.carlrobert.codegpt.conversations.message.Message;
import java.util.List;
import org.junit.Test;
// Unit tests for PromptTemplate: each template must render the system prompt,
// the new user prompt, and (optionally) prior exchanges into the exact wire
// format its model family expects. Every test pins the fully rendered string.
// NOTE(review): the commit message says "Remove ToRA model", yet TORA tests
// remain here — confirm whether the template (and these tests) should be gone.
public class PromptTemplateTest {
// Fixed inputs shared by all template tests.
private static final String SYSTEM_PROMPT = "TEST_SYSTEM_PROMPT";
private static final String USER_PROMPT = "TEST_USER_PROMPT";
// Two completed prompt/response exchanges used as conversation history.
private static final List<Message> HISTORY = List.of(
new Message("TEST_PREV_PROMPT_1", "TEST_PREV_RESPONSE_1"),
new Message("TEST_PREV_PROMPT_2", "TEST_PREV_RESPONSE_2"));
@Test
public void shouldBuildLlamaPromptWithHistory() {
var prompt = LLAMA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY);
// Llama 2 format: <<SYS>> wraps the system prompt once; each user turn is
// wrapped in [INST]...[/INST] with the assistant reply on the next line.
assertThat(prompt).isEqualTo(
"<<SYS>>TEST_SYSTEM_PROMPT<</SYS>>\n"
+ "[INST]TEST_PREV_PROMPT_1[/INST]\n"
+ "TEST_PREV_RESPONSE_1\n"
+ "[INST]TEST_PREV_PROMPT_2[/INST]\n"
+ "TEST_PREV_RESPONSE_2\n"
+ "[INST]TEST_USER_PROMPT[/INST]");
}
@Test
public void shouldBuildLlamaPromptWithoutHistory() {
// With no history the prompt collapses to system block + single instruction.
var prompt = LLAMA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, List.of());
assertThat(prompt).isEqualTo(
"<<SYS>>TEST_SYSTEM_PROMPT<</SYS>>\n"
+ "[INST]TEST_USER_PROMPT[/INST]");
}
@Test
public void shouldBuildAlpacaPromptWithHistory() {
var prompt = ALPACA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY);
// Alpaca format: fixed preamble, then "### Instruction"/"### Response:"
// pairs for each turn, ending with an open Response section.
// NOTE(review): SYSTEM_PROMPT is not rendered here — the template appears
// to use its own fixed preamble instead; confirm that is intentional.
assertThat(prompt).isEqualTo(
"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n"
+ "\n"
+ "### Instruction\n"
+ "TEST_PREV_PROMPT_1\n"
+ "\n"
+ "### Response:\n"
+ "TEST_PREV_RESPONSE_1\n"
+ "\n"
+ "### Instruction\n"
+ "TEST_PREV_PROMPT_2\n"
+ "\n"
+ "### Response:\n"
+ "TEST_PREV_RESPONSE_2\n"
+ "\n"
+ "### Instruction\n"
+ "TEST_USER_PROMPT\n"
+ "\n"
+ "### Response:\n");
}
@Test
public void shouldBuildAlpacaPromptWithoutHistory() {
var prompt = ALPACA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, List.of());
assertThat(prompt).isEqualTo(
"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n"
+ "\n"
+ "### Instruction\n"
+ "TEST_USER_PROMPT\n"
+ "\n"
+ "### Response:\n");
}
@Test
public void shouldBuildChatMLPromptWithHistory() {
var prompt = CHAT_ML.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY);
// ChatML format: <|im_start|>ROLE ... <|im_end|> blocks; system first,
// then alternating user/assistant turns, ending on the new user turn.
assertThat(prompt).isEqualTo(
"<|im_start|>system\n"
+ "TEST_SYSTEM_PROMPT<|im_end|>\n"
+ "<|im_start|>user\n"
+ "TEST_PREV_PROMPT_1<|im_end|>\n"
+ "<|im_start|>assistant\n"
+ "TEST_PREV_RESPONSE_1<|im_end|>\n"
+ "<|im_start|>user\n"
+ "TEST_PREV_PROMPT_2<|im_end|>\n"
+ "<|im_start|>assistant\n"
+ "TEST_PREV_RESPONSE_2<|im_end|>\n"
+ "<|im_start|>user\n"
+ "TEST_USER_PROMPT<|im_end|>"
);
}
@Test
public void shouldBuildChatMLPromptWithoutHistory() {
var prompt = CHAT_ML.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, List.of());
assertThat(prompt).isEqualTo(
"<|im_start|>system\n"
+ "TEST_SYSTEM_PROMPT<|im_end|>\n"
+ "<|im_start|>user\n"
+ "TEST_USER_PROMPT<|im_end|>");
}
@Test
public void shouldBuildToRAPromptWithHistory() {
var prompt = TORA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY);
// ToRA format: <|user|>/<|assistant|> markers only; the system prompt is
// not rendered at all, and the prompt ends with an open assistant marker.
assertThat(prompt).isEqualTo(
"<|user|>\n"
+ "TEST_PREV_PROMPT_1\n"
+ "<|assistant|>\n"
+ "TEST_PREV_RESPONSE_1\n"
+ "<|user|>\n"
+ "TEST_PREV_PROMPT_2\n"
+ "<|assistant|>\n"
+ "TEST_PREV_RESPONSE_2\n"
+ "<|user|>\n"
+ "TEST_USER_PROMPT\n"
+ "<|assistant|>"
);
}
@Test
public void shouldBuildToRAPromptWithoutHistory() {
var prompt = TORA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, List.of());
assertThat(prompt).isEqualTo(
"<|user|>\n"
+ "TEST_USER_PROMPT\n"
+ "<|assistant|>"
);
}
}