diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..215acb1e --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "llama.cpp"] + path = src/main/cpp/llama.cpp + url = https://github.com/ggerganov/llama.cpp diff --git a/DESCRIPTION.md b/DESCRIPTION.md index f45c5293..548aeaa9 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -1,24 +1,66 @@ -**ChatGPT as your copilot to level up your developer experience.** +## Introducing CodeGPT: Your Free, Open-Source AI Copilot for Coding -This is the perfect assistant for any programmer who wants to improve their coding skills -and make more efficient use of the time. +CodeGPT is your go-to AI assistant, designed to enhance your coding skills and optimize your programming time. +Access state-of-the-art LLMs like GPT-4, Code LLama and more, all for free. -## Getting Started +## Quick Start Guide -### Prerequisites +1. **Download the Plugin**: Get started by downloading the plugin from the [JetBrains Marketplace](https://plugins.jetbrains.com/plugin/21056-codegpt?preview=true). -In order to use the extension, you need to have the API key configured. You can find the API key in -your [User settings](https://platform.openai.com/account/api-keys). +2. **Choose Your Preferred Service**: -### API Key Configuration + a) **OpenAI** - Requires authentication via OpenAI API key. -After the plugin has been successfully installed, the API key needs to be configured. + b) **Azure** - Requires authentication via Active Directory or API key. -You can configure the key by going to the plugin's settings via the **File | Settings/Preferences | Tools | CodeGPT**. On the settings panel simply -click -on the API key field, paste the key obtained from the OpenAI website and click **Apply/OK**. + c) **You.com** - A free, web-connected service with an optional upgrade to You⚡Pro for enhanced features.. + + d) **LLaMA C/C++ Port** - Run Code Llama, WizardCoder, and other state-of-the-art models locally for free. + +3. **Start Using the Features**: You're all set! Start exploring the features of our plugin. + +### OpenAI + +After successful installation, configure your API key. Navigate to the plugin's settings via **File | Settings/Preferences | Tools | CodeGPT**. Paste your OpenAI API key into the field and click `Apply/OK`. + +### Azure + +For Azure OpenAI services, you'll need to input three additional fields: +* `Resource name`: The name of your Azure OpenAI Cognitive Services. +* `Deployment ID`: The name of your Deployment. +* `API version`: The most recent non-preview version. + +Also, input one of the two provided API keys. + +### You.com (Free) + +**You**.com is a search engine that summarizes the best parts of the internet for **you**, with private ads and with privacy options. + +**You⚡Pro** + +Use the **CodeGPT** coupon for a free month of unlimited GPT-4 usage. + +Check out the full [feature list](https://about.you.com/hc/youpro/what-features-are-included-in-youpro/) for more details. + +### LLaMA C/C++ Port (Free, Local) + +> **Note**: This feature is currently supported only on Linux and MacOS. + +The main goal of `llama.cpp` is to run the LLaMA model using 4-bit integer quantization on a MacBook. + +#### Getting Started + +1. **Select the Model**: Depending on your hardware capabilities, choose the appropriate model from the provided list. Once selected, click on the `Download Model` link. A progress bar will appear, indicating the download process. + +2. **Start the Server**: After successfully downloading the model, initiate the server by clicking on the `Start Server` button. A status message will be displayed, indicating that the server is starting up. + +3. **Apply Settings**: With the server running, you can now apply the settings to start using the features. Click on the `Apply/OK` button to save your settings and start using the application. + +animated + +> **Note**: If you're already running a server and wish to configure the plugin against that, then simply select the port and click `Apply/OK`. ## Features diff --git a/README.md b/README.md index cd4ba38b..31960365 100644 --- a/README.md +++ b/README.md @@ -116,6 +116,21 @@ Expected a different answer? Re-generate any response of your choosing. - **Seamless conversations** - Chat with the AI regardless of the maximum token limitations - **Predefined Actions** - Create your own editor actions or override the existing ones, saving time rewriting the same prompt repeatedly +### Running locally + +**Linux or macOS** +```shell +git clone https://github.com/carlrobertoh/CodeGPT.git +cd CodeGPT +git submodule update +./gradlew runIde +``` + +**Windows ARM64** +```shell +./gradlew runIde -Penv=win-arm64 +``` + ## Issues See the [open issues][open-issues] for a full list of proposed features (and known issues). diff --git a/build.gradle.kts b/build.gradle.kts index 44e74a2b..051614ea 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -1,8 +1,24 @@ import org.gradle.api.tasks.testing.logging.TestExceptionFormat import org.jetbrains.changelog.Changelog import org.jetbrains.changelog.markdownToHTML +import java.io.FileInputStream +import java.util.* + +val env = environment("env").getOrNull() + +fun loadProperties(filename: String): Properties = Properties().apply { + load(FileInputStream(filename)) +} + +fun properties(key: String): Provider { + if ("win-arm64" == env) { + val property = loadProperties("gradle-win-arm64.properties").getProperty(key) + ?: return providers.gradleProperty(key) + return providers.provider { property } + } + return providers.gradleProperty(key) +} -fun properties(key: String) = providers.gradleProperty(key) fun environment(key: String) = providers.environmentVariable(key) plugins { @@ -53,6 +69,17 @@ dependencies { testRuntimeOnly("org.junit.vintage:junit-vintage-engine:5.10.0") } +tasks.register("updateSubmodules") { + workingDir(rootDir) + commandLine("git", "submodule", "update", "--init", "--recursive") +} + +tasks.register("copyLlamaSubmodule") { + dependsOn("updateSubmodules") + from(layout.projectDirectory.file("src/main/cpp/llama.cpp")) + into(layout.buildDirectory.dir("idea-sandbox/plugins/CodeGPT/llama.cpp")) +} + tasks { wrapper { gradleVersion = properties("gradleVersion").get() @@ -98,6 +125,11 @@ tasks { }) } + prepareSandbox { + enabled = true + dependsOn("copyLlamaSubmodule") + } + signPlugin { enabled = true certificateChain.set(System.getenv("CERTIFICATE_CHAIN")) @@ -105,6 +137,10 @@ tasks { password.set(System.getenv("PRIVATE_KEY_PASSWORD")) } + buildPlugin { + enabled = true + } + publishPlugin { enabled = true dependsOn("patchChangelog") @@ -125,4 +161,4 @@ tasks { showStandardStreams = true } } -} +} \ No newline at end of file diff --git a/buildSrc/src/main/kotlin/codegpt.java-conventions.gradle.kts b/buildSrc/src/main/kotlin/codegpt.java-conventions.gradle.kts index 72b32f93..a8ec1e7b 100644 --- a/buildSrc/src/main/kotlin/codegpt.java-conventions.gradle.kts +++ b/buildSrc/src/main/kotlin/codegpt.java-conventions.gradle.kts @@ -18,7 +18,7 @@ intellij { } dependencies { - implementation("ee.carlrobert:llm-client:0.0.6") + implementation("ee.carlrobert:llm-client:0.0.7") } tasks { diff --git a/docs/assets/llama_settings.png b/docs/assets/llama_settings.png new file mode 100644 index 00000000..0d5c45c7 Binary files /dev/null and b/docs/assets/llama_settings.png differ diff --git a/gradle-win-arm64.properties b/gradle-win-arm64.properties new file mode 100644 index 00000000..106b0fd5 --- /dev/null +++ b/gradle-win-arm64.properties @@ -0,0 +1,2 @@ +platformVersion = 2023.1 +javaVersion = 17 \ No newline at end of file diff --git a/gradle.properties b/gradle.properties index 28464df3..7aee66ae 100644 --- a/gradle.properties +++ b/gradle.properties @@ -32,6 +32,8 @@ org.gradle.configuration-cache = true # Enable Gradle Build Cache -> https://docs.gradle.org/current/userguide/build_cache.html org.gradle.caching = true +# org.gradle.logging.level=debug + # Enable Gradle Kotlin DSL Lazy Property Assignment -> https://docs.gradle.org/current/userguide/kotlin_dsl.html#kotdsl:assignment systemProp.org.gradle.unsafe.kotlin.assignment = true diff --git a/src/main/cpp/llama.cpp b/src/main/cpp/llama.cpp new file mode 160000 index 00000000..b8fe4b5c --- /dev/null +++ b/src/main/cpp/llama.cpp @@ -0,0 +1 @@ +Subproject commit b8fe4b5cc9cb237ca98e5bc51b5d189e3c446d13 diff --git a/src/main/java/ee/carlrobert/codegpt/CodeGPTPlugin.java b/src/main/java/ee/carlrobert/codegpt/CodeGPTPlugin.java index 0affdaa6..65a8b21b 100644 --- a/src/main/java/ee/carlrobert/codegpt/CodeGPTPlugin.java +++ b/src/main/java/ee/carlrobert/codegpt/CodeGPTPlugin.java @@ -6,8 +6,10 @@ import com.intellij.ide.plugins.PluginManagerCore; import com.intellij.openapi.application.PathManager; import com.intellij.openapi.extensions.PluginId; import com.intellij.openapi.project.Project; +import ee.carlrobert.codegpt.telemetry.core.util.Directories; import java.io.File; import java.nio.file.Path; +import java.nio.file.Paths; import org.jetbrains.annotations.NotNull; public final class CodeGPTPlugin { @@ -33,6 +35,14 @@ public final class CodeGPTPlugin { return getPluginOptionsPath() + File.separator + "indexes"; } + public static @NotNull String getLlamaSourcePath() { + return getPluginBasePath() + File.separator + "llama.cpp"; + } + + public static @NotNull String getLlamaModelsPath() { + return Paths.get(System.getProperty("user.home"), ".codegpt/models/gguf").toString(); + } + public static @NotNull String getProjectIndexStorePath(@NotNull Project project) { return getIndexStorePath() + File.separator + project.getName(); } diff --git a/src/main/java/ee/carlrobert/codegpt/Icons.java b/src/main/java/ee/carlrobert/codegpt/Icons.java index 9c4b2ca9..3eef6b2a 100644 --- a/src/main/java/ee/carlrobert/codegpt/Icons.java +++ b/src/main/java/ee/carlrobert/codegpt/Icons.java @@ -11,4 +11,5 @@ public final class Icons { public static final Icon OpenAIIcon = IconLoader.getIcon("/icons/openai.svg", Icons.class); public static final Icon AzureIcon = IconLoader.getIcon("/icons/azure.svg", Icons.class); public static final Icon YouIcon = IconLoader.getIcon("/icons/you.svg", Icons.class); + public static final Icon LlamaIcon = IconLoader.getIcon("/icons/llama.svg", Icons.class); } diff --git a/src/main/java/ee/carlrobert/codegpt/actions/editor/EditorActionsUtil.java b/src/main/java/ee/carlrobert/codegpt/actions/editor/EditorActionsUtil.java index a43fc7d4..54c6dce5 100644 --- a/src/main/java/ee/carlrobert/codegpt/actions/editor/EditorActionsUtil.java +++ b/src/main/java/ee/carlrobert/codegpt/actions/editor/EditorActionsUtil.java @@ -38,7 +38,7 @@ public class EditorActionsUtil { } public static void refreshActions() { - AnAction actionGroup = ActionManager.getInstance().getAction("action.editor.group.EditorActionGroup"); + AnAction actionGroup = ActionManager.getInstance().getAction("project.label"); if (actionGroup instanceof DefaultActionGroup) { DefaultActionGroup group = (DefaultActionGroup) actionGroup; group.removeAll(); diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionClientProvider.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionClientProvider.java index fd291e15..7f3767ca 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionClientProvider.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionClientProvider.java @@ -5,11 +5,14 @@ import ee.carlrobert.codegpt.credentials.AzureCredentialsManager; import ee.carlrobert.codegpt.credentials.OpenAICredentialsManager; import ee.carlrobert.codegpt.settings.advanced.AdvancedSettingsState; import ee.carlrobert.codegpt.settings.state.AzureSettingsState; +import ee.carlrobert.codegpt.settings.state.LlamaSettingsState; import ee.carlrobert.codegpt.settings.state.OpenAISettingsState; +import ee.carlrobert.codegpt.settings.state.YouSettingsState; import ee.carlrobert.llm.client.Client; import ee.carlrobert.llm.client.ProxyAuthenticator; import ee.carlrobert.llm.client.azure.AzureClient; import ee.carlrobert.llm.client.azure.AzureCompletionRequestParams; +import ee.carlrobert.llm.client.llama.LlamaClient; import ee.carlrobert.llm.client.openai.OpenAIClient; import ee.carlrobert.llm.client.you.UTMParameters; import ee.carlrobert.llm.client.you.YouClient; @@ -33,8 +36,16 @@ public class CompletionClientProvider { utmParameters.setMedium("jetbrains"); utmParameters.setCampaign(CodeGPTPlugin.getVersion()); utmParameters.setContent("CodeGPT"); - return new YouClient.Builder(sessionId, accessToken) + // FIXME + return (YouClient) new YouClient.Builder(sessionId, accessToken) .setUTMParameters(utmParameters) + .setHost(YouSettingsState.getInstance().getBaseHost()) + .build(); + } + + public static LlamaClient getLlamaClient() { + return new LlamaClient.Builder() + .setPort(LlamaSettingsState.getInstance().getServerPort()) .build(); } @@ -65,10 +76,9 @@ public class CompletionClientProvider { builder.setProxy( new Proxy(advancedSettings.getProxyType(), new InetSocketAddress(proxyHost, proxyPort))); if (advancedSettings.isProxyAuthSelected()) { - builder.setProxyAuthenticator( - new ProxyAuthenticator( - advancedSettings.getProxyUsername(), - advancedSettings.getProxyPassword())); + builder.setProxyAuthenticator(new ProxyAuthenticator( + advancedSettings.getProxyUsername(), + advancedSettings.getProxyPassword())); } } diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestHandler.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestHandler.java index c1af6ccf..5f0bb7b6 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestHandler.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestHandler.java @@ -80,6 +80,11 @@ public class CompletionRequestHandler { var requestProvider = new CompletionRequestProvider(conversation); try { + if (settings.isUseLlamaService()) { + return CompletionClientProvider.getLlamaClient() + .getChatCompletion(requestProvider.buildLlamaCompletionRequest(message), eventListener); + } + if (settings.isUseYouService()) { var sessionId = ""; var accessToken = ""; diff --git a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java index 95c1d0b6..9fc0adb0 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/CompletionRequestProvider.java @@ -2,19 +2,23 @@ package ee.carlrobert.codegpt.completions; import static java.util.stream.Collectors.toList; +import com.intellij.openapi.application.ApplicationManager; import com.intellij.openapi.diagnostic.Logger; import ee.carlrobert.codegpt.CodeGPTPlugin; import ee.carlrobert.codegpt.EncodingManager; +import ee.carlrobert.codegpt.completions.llama.LlamaModel; import ee.carlrobert.codegpt.conversations.Conversation; import ee.carlrobert.codegpt.conversations.ConversationsState; import ee.carlrobert.codegpt.conversations.message.Message; import ee.carlrobert.codegpt.settings.configuration.ConfigurationState; +import ee.carlrobert.codegpt.settings.state.LlamaSettingsState; import ee.carlrobert.codegpt.settings.state.SettingsState; import ee.carlrobert.codegpt.settings.state.YouSettingsState; import ee.carlrobert.codegpt.telemetry.core.configuration.TelemetryConfiguration; -import ee.carlrobert.codegpt.telemetry.core.service.TelemetryService; import ee.carlrobert.codegpt.telemetry.core.service.UserId; +import ee.carlrobert.codegpt.util.ApplicationUtils; import ee.carlrobert.embedding.EmbeddingsService; +import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest; import ee.carlrobert.llm.client.openai.completion.chat.OpenAIChatCompletionModel; import ee.carlrobert.llm.client.openai.completion.chat.request.OpenAIChatCompletionMessage; import ee.carlrobert.llm.client.openai.completion.chat.request.OpenAIChatCompletionRequest; @@ -36,17 +40,22 @@ public class CompletionRequestProvider { "Follow the user's requirements carefully & to the letter.\n" + "Your responses should be informative and logical.\n" + "You should always adhere to technical information.\n" + - "If the user asks for code or technical questions, you must provide code suggestions and adhere to technical information.\n" + - "If the question is related to a developer, CodeGPT must respond with content related to a developer.\n" + - "First think step-by-step - describe your plan for what to build in pseudocode, written out in great detail.\n" + + "If the user asks for code or technical questions, you must provide code suggestions and " + + "adhere to technical information.\n" + + "If the question is related to a developer, CodeGPT must respond with " + + "content related to a developer.\n" + + "First think step-by-step - describe your plan for what to build in pseudocode, " + + "written out in great detail.\n" + "Then output the code in a single code block.\n" + "Minimize any other prose.\n" + "Keep your answers short and impersonal.\n" + "Use Markdown formatting in your answers.\n" + - "Make sure to include the programming language name at the start of the Markdown code blocks.\n" + + "Make sure to include the programming language name at the start of the " + + "Markdown code blocks.\n" + "Avoid wrapping the whole response in triple backticks.\n" + - "The user works in an IDE built by JetBrains which has a concept for editors with open files, integrated unit test support, " + - "and output pane that shows the output of running the code as well as an integrated terminal.\n" + + "The user works in an IDE built by JetBrains which has a concept for editors " + + "with open files, integrated unit test support, and output pane that shows " + + "the output of running the code as well as an integrated terminal.\n" + "You can only give one reply for each conversation turn."; private final EncodingManager encodingManager = EncodingManager.getInstance(); @@ -60,6 +69,20 @@ public class CompletionRequestProvider { this.conversation = conversation; } + public LlamaCompletionRequest buildLlamaCompletionRequest(Message message) { + var settings = LlamaSettingsState.getInstance(); + var promptTemplate = settings.isUseCustomModel() ? + settings.getPromptTemplate() : + LlamaModel.findByHuggingFaceModel(settings.getHuggingFaceModel()).getPromptTemplate(); + var prompt = promptTemplate.buildPrompt( + COMPLETION_SYSTEM_PROMPT, + message.getPrompt(), + conversation.getMessages()); + return new LlamaCompletionRequest.Builder(prompt) + .setN_predict(512) + .build(); + } + public YouCompletionRequest buildYouCompletionRequest(Message message) { var requestBuilder = new YouCompletionRequest.Builder(message.getPrompt()) .setUseGPT4Model(YouSettingsState.getInstance().isUseGPT4Model()) @@ -68,7 +91,8 @@ public class CompletionRequestProvider { prevMessage.getPrompt(), prevMessage.getResponse())) .collect(toList())); - if (TelemetryConfiguration.getInstance().isEnabled()) { + if (TelemetryConfiguration.getInstance().isEnabled() && + !ApplicationManager.getApplication().isUnitTestMode()) { requestBuilder.setUserId(UUID.fromString(UserId.INSTANCE.get())); } return requestBuilder.build(); diff --git a/src/main/java/ee/carlrobert/codegpt/completions/HuggingFaceModel.java b/src/main/java/ee/carlrobert/codegpt/completions/HuggingFaceModel.java new file mode 100644 index 00000000..a3a8784e --- /dev/null +++ b/src/main/java/ee/carlrobert/codegpt/completions/HuggingFaceModel.java @@ -0,0 +1,85 @@ +package ee.carlrobert.codegpt.completions; + +import static java.lang.String.format; + +import java.net.MalformedURLException; +import java.net.URL; + +public enum HuggingFaceModel { + + CODE_LLAMA_7B_Q3(7, 3, "CodeLlama-7B-Instruct-GGUF"), + CODE_LLAMA_7B_Q4(7, 4, "CodeLlama-7B-Instruct-GGUF"), + CODE_LLAMA_7B_Q5(7, 5, "CodeLlama-7B-Instruct-GGUF"), + CODE_LLAMA_13B_Q3(13, 3, "CodeLlama-13B-Instruct-GGUF"), + CODE_LLAMA_13B_Q4(13, 4, "CodeLlama-13B-Instruct-GGUF"), + CODE_LLAMA_13B_Q5(13, 5, "CodeLlama-13B-Instruct-GGUF"), + CODE_LLAMA_34B_Q3(34, 3, "CodeLlama-34B-Instruct-GGUF"), + CODE_LLAMA_34B_Q4(34, 4, "CodeLlama-34B-Instruct-GGUF"), + CODE_LLAMA_34B_Q5(34, 5, "CodeLlama-34B-Instruct-GGUF"), + + CODE_BOOGA_34B_Q3(34, 3, "CodeBooga-34B-v0.1-GGUF"), + CODE_BOOGA_34B_Q4(34, 4, "CodeBooga-34B-v0.1-GGUF"), + CODE_BOOGA_34B_Q5(34, 5, "CodeBooga-34B-v0.1-GGUF"), + + PHIND_CODE_LLAMA_34B_Q3(34, 3, "Phind-CodeLlama-34B-v2-GGUF"), + PHIND_CODE_LLAMA_34B_Q4(34, 4, "Phind-CodeLlama-34B-v2-GGUF"), + PHIND_CODE_LLAMA_34B_Q5(34, 5, "Phind-CodeLlama-34B-v2-GGUF"), + + WIZARD_CODER_PYTHON_7B_Q3(7, 3, "WizardCoder-Python-7B-V1.0-GGUF"), + WIZARD_CODER_PYTHON_7B_Q4(7, 4, "WizardCoder-Python-7B-V1.0-GGUF"), + WIZARD_CODER_PYTHON_7B_Q5(7, 5, "WizardCoder-Python-7B-V1.0-GGUF"), + WIZARD_CODER_PYTHON_13B_Q3(13, 3, "WizardCoder-Python-13B-V1.0-GGUF"), + WIZARD_CODER_PYTHON_13B_Q4(13, 4, "WizardCoder-Python-13B-V1.0-GGUF"), + WIZARD_CODER_PYTHON_13B_Q5(13, 5, "WizardCoder-Python-13B-V1.0-GGUF"), + WIZARD_CODER_PYTHON_34B_Q3(34, 3, "WizardCoder-Python-34B-V1.0-GGUF"), + WIZARD_CODER_PYTHON_34B_Q4(34, 4, "WizardCoder-Python-34B-V1.0-GGUF"), + WIZARD_CODER_PYTHON_34B_Q5(34, 5, "WizardCoder-Python-34B-V1.0-GGUF"); + + private final int parameterSize; + private final int quantization; + private final String modelName; + + HuggingFaceModel(int parameterSize, int quantization, String modelName) { + this.parameterSize = parameterSize; + this.quantization = quantization; + this.modelName = modelName; + } + + public int getParameterSize() { + return parameterSize; + } + + public int getQuantization() { + return quantization; + } + + public String getCode() { + return name(); + } + + public String getFileName() { + return modelName.toLowerCase().replace("-gguf", format(".Q%d_K_M.gguf", quantization)); + } + + public URL getFileURL() { + try { + return new URL( + format("https://huggingface.co/TheBloke/%s/resolve/main/%s", modelName, getFileName())); + } catch (MalformedURLException ex) { + throw new RuntimeException(ex); + } + } + + public URL getHuggingFaceURL() { + try { + return new URL("https://huggingface.co/TheBloke/" + modelName); + } catch (MalformedURLException ex) { + throw new RuntimeException(ex); + } + } + + @Override + public String toString() { + return format("%d-bit precision", quantization); + } +} diff --git a/src/main/java/ee/carlrobert/codegpt/completions/llama/LlamaModel.java b/src/main/java/ee/carlrobert/codegpt/completions/llama/LlamaModel.java new file mode 100644 index 00000000..323749fc --- /dev/null +++ b/src/main/java/ee/carlrobert/codegpt/completions/llama/LlamaModel.java @@ -0,0 +1,124 @@ +package ee.carlrobert.codegpt.completions.llama; + +import static java.lang.String.format; +import static java.util.stream.Collectors.toList; +import static java.util.stream.Collectors.toSet; + +import ee.carlrobert.codegpt.completions.HuggingFaceModel; +import java.util.Collections; +import java.util.List; +import org.jetbrains.annotations.NotNull; + +public enum LlamaModel { + CODE_LLAMA( + "Code Llama", + "Code Llama is a family of large language models for code based on Llama 2 providing state-of-the-art performance among open models, infilling capabilities, support for large input contexts, and zero-shot instruction following ability for programming tasks.", + PromptTemplate.LLAMA, + List.of( + HuggingFaceModel.CODE_LLAMA_7B_Q3, + HuggingFaceModel.CODE_LLAMA_7B_Q4, + HuggingFaceModel.CODE_LLAMA_7B_Q5, + HuggingFaceModel.CODE_LLAMA_13B_Q3, + HuggingFaceModel.CODE_LLAMA_13B_Q4, + HuggingFaceModel.CODE_LLAMA_13B_Q5, + HuggingFaceModel.CODE_LLAMA_34B_Q3, + HuggingFaceModel.CODE_LLAMA_34B_Q4, + HuggingFaceModel.CODE_LLAMA_34B_Q5) + ), + CODE_BOOGA( + "CodeBooga", + "CodeBooga is a high-performing code instruct model created by merging two existing code models:
  1. Phind-CodeLlama-34B-v2
  2. WizardCoder-Python-34B-V1.0
", + PromptTemplate.ALPACA, + List.of( + HuggingFaceModel.CODE_BOOGA_34B_Q3, + HuggingFaceModel.CODE_BOOGA_34B_Q4, + HuggingFaceModel.CODE_BOOGA_34B_Q5)), + PHIND_CODE_LLAMA( + "Phind Code Llama", + "This model is fine-tuned from Phind-CodeLlama-34B-v1 on an additional 1.5B tokens high-quality programming-related data, achieving 73.8% pass@1 on HumanEval. It's the current state-of-the-art amongst open-source models.", + PromptTemplate.ALPACA, + List.of( + HuggingFaceModel.PHIND_CODE_LLAMA_34B_Q3, + HuggingFaceModel.PHIND_CODE_LLAMA_34B_Q4, + HuggingFaceModel.PHIND_CODE_LLAMA_34B_Q5)), + WIZARD_CODER_PYTHON( + "WizardCoder - Python", + "WizardCoder, a Code Evol-Instruct fine-tuned Code LLM, which achieves the 73.2 pass@1 and surpasses GPT4 (2023/03/15), ChatGPT-3.5, and Claude2 on the HumanEval Benchmarks.", + PromptTemplate.ALPACA, + List.of( + HuggingFaceModel.WIZARD_CODER_PYTHON_7B_Q3, + HuggingFaceModel.WIZARD_CODER_PYTHON_7B_Q4, + HuggingFaceModel.WIZARD_CODER_PYTHON_7B_Q5, + HuggingFaceModel.WIZARD_CODER_PYTHON_13B_Q3, + HuggingFaceModel.WIZARD_CODER_PYTHON_13B_Q4, + HuggingFaceModel.WIZARD_CODER_PYTHON_13B_Q5, + HuggingFaceModel.WIZARD_CODER_PYTHON_34B_Q3, + HuggingFaceModel.WIZARD_CODER_PYTHON_34B_Q4, + HuggingFaceModel.WIZARD_CODER_PYTHON_34B_Q5)); + + private final String label; + private final String description; + private final PromptTemplate promptTemplate; + private final List huggingFaceModels; + + LlamaModel( + String label, + String description, + PromptTemplate promptTemplate, + List huggingFaceModels) { + this.label = label; + this.description = description; + this.promptTemplate = promptTemplate; + this.huggingFaceModels = huggingFaceModels; + } + + public static @NotNull LlamaModel findByHuggingFaceModel(HuggingFaceModel huggingFaceModel) { + for (var llamaModel : LlamaModel.values()) { + if (llamaModel.getHuggingFaceModels().contains(huggingFaceModel)) { + return llamaModel; + } + } + + throw new RuntimeException("Unable to find correct LLM"); + } + + @Override + public String toString() { + return String.join(" ", label, getFormattedModelSizeRange()); + } + + public String getLabel() { + return label; + } + + public String getDescription() { + return description; + } + + public PromptTemplate getPromptTemplate() { + return promptTemplate; + } + + public List getHuggingFaceModels() { + return huggingFaceModels; + } + + public String getFormattedModelSizeRange() { + var parameters = huggingFaceModels.stream() + .map(HuggingFaceModel::getParameterSize) + .collect(toSet()); + if (parameters.size() == 1) { + return parameters.iterator().next() + "B"; + } + return format("(%dB - %dB)", Collections.min(parameters), Collections.max(parameters)); + } + + public List getSortedUniqueModelSizes() { + return huggingFaceModels.stream() + .map(HuggingFaceModel::getParameterSize) + .collect(toSet()) + .stream() + .sorted() + .collect(toList()); + } +} \ No newline at end of file diff --git a/src/main/java/ee/carlrobert/codegpt/completions/llama/LlamaServerAgent.java b/src/main/java/ee/carlrobert/codegpt/completions/llama/LlamaServerAgent.java new file mode 100644 index 00000000..03e1561d --- /dev/null +++ b/src/main/java/ee/carlrobert/codegpt/completions/llama/LlamaServerAgent.java @@ -0,0 +1,156 @@ +package ee.carlrobert.codegpt.completions.llama; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.intellij.execution.ExecutionException; +import com.intellij.execution.configurations.GeneralCommandLine; +import com.intellij.execution.process.OSProcessHandler; +import com.intellij.execution.process.ProcessAdapter; +import com.intellij.execution.process.ProcessEvent; +import com.intellij.execution.process.ProcessListener; +import com.intellij.execution.process.ProcessOutputType; +import com.intellij.icons.AllIcons.Actions; +import com.intellij.openapi.Disposable; +import com.intellij.openapi.application.ApplicationManager; +import com.intellij.openapi.components.Service; +import com.intellij.openapi.diagnostic.Logger; +import com.intellij.openapi.util.Key; +import com.intellij.ui.components.JBLabel; +import ee.carlrobert.codegpt.CodeGPTPlugin; +import ee.carlrobert.codegpt.settings.service.ServerProgressPanel; +import ee.carlrobert.codegpt.settings.state.LlamaSettingsState; +import java.nio.charset.StandardCharsets; +import javax.swing.SwingConstants; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; + +@Service +public final class LlamaServerAgent implements Disposable { + + private static final Logger LOG = Logger.getInstance(LlamaServerAgent.class); + + private static @Nullable OSProcessHandler makeProcessHandler; + private static @Nullable OSProcessHandler startServerProcessHandler; + + public void startAgent( + String modelPath, + int contextLength, + int port, + ServerProgressPanel serverProgressPanel, + Runnable onSuccess) { + ApplicationManager.getApplication().invokeLater(() -> { + try { + serverProgressPanel.updateText("Building llama.cpp..."); + makeProcessHandler = new OSProcessHandler(getMakeCommandLinde()); + makeProcessHandler.addProcessListener( + getMakeProcessListener(modelPath, contextLength, port, serverProgressPanel, onSuccess)); + makeProcessHandler.startNotify(); + } catch (ExecutionException e) { + throw new RuntimeException(e); + } + }); + } + + public void stopAgent() { + if (startServerProcessHandler != null) { + startServerProcessHandler.destroyProcess(); + } + } + + public boolean isServerRunning() { + return startServerProcessHandler != null && + startServerProcessHandler.isStartNotified() && + !startServerProcessHandler.isProcessTerminated(); + } + + private ProcessListener getMakeProcessListener( + String modelPath, + int contextLength, + int port, + ServerProgressPanel serverProgressPanel, + Runnable onSuccess) { + return new ProcessAdapter() { + @Override + public void onTextAvailable(@NotNull ProcessEvent event, @NotNull Key outputType) { + LOG.info(event.getText()); + } + + @Override + public void processTerminated(@NotNull ProcessEvent event) { + try { + serverProgressPanel.updateText("Booting up server..."); + startServerProcessHandler = new OSProcessHandler( + getServerCommandLine(modelPath, contextLength, port)); + startServerProcessHandler.addProcessListener( + getProcessListener(port, serverProgressPanel, onSuccess)); + startServerProcessHandler.startNotify(); + } catch (ExecutionException e) { + throw new RuntimeException(e); + } + } + }; + } + + private ProcessListener getProcessListener( + int port, + ServerProgressPanel serverProgressPanel, + Runnable onSuccess) { + return new ProcessAdapter() { + private final ObjectMapper objectMapper = new ObjectMapper(); + + @Override + public void processTerminated(@NotNull ProcessEvent event) { + serverProgressPanel.displayComponent(new JBLabel( + "Server terminated", + Actions.Cancel, + SwingConstants.LEADING)); + } + + @Override + public void onTextAvailable(@NotNull ProcessEvent event, @NotNull Key outputType) { + LOG.debug(event.getText()); + + if (outputType == ProcessOutputType.STDOUT) { + try { + var serverMessage = objectMapper.readValue(event.getText(), LlamaServerMessage.class); + if ("HTTP server listening".equals(serverMessage.getMessage())) { + LlamaSettingsState.getInstance().setServerPort(port); + onSuccess.run(); + } + } catch (Exception ignore) { + } + } + } + }; + } + + private static GeneralCommandLine getMakeCommandLinde() { + GeneralCommandLine commandLine = new GeneralCommandLine().withCharset(StandardCharsets.UTF_8); + commandLine.setExePath("make"); + commandLine.withWorkDirectory(CodeGPTPlugin.getLlamaSourcePath()); + commandLine.addParameters("-j"); + commandLine.setRedirectErrorStream(false); + return commandLine; + } + + private GeneralCommandLine getServerCommandLine(String modelPath, int contextLength, int port) { + GeneralCommandLine commandLine = new GeneralCommandLine().withCharset(StandardCharsets.UTF_8); + commandLine.setExePath("./server"); + commandLine.withWorkDirectory(CodeGPTPlugin.getLlamaSourcePath()); + commandLine.addParameters( + "-m", modelPath, + "-c", String.valueOf(contextLength), + "--port", String.valueOf(port)); + commandLine.setRedirectErrorStream(false); + return commandLine; + } + + @Override + public void dispose() { + if (makeProcessHandler != null && !makeProcessHandler.isProcessTerminated()) { + makeProcessHandler.destroyProcess(); + } + if (startServerProcessHandler != null && !startServerProcessHandler.isProcessTerminated()) { + startServerProcessHandler.destroyProcess(); + } + } +} diff --git a/src/main/java/ee/carlrobert/codegpt/completions/llama/LlamaServerMessage.java b/src/main/java/ee/carlrobert/codegpt/completions/llama/LlamaServerMessage.java new file mode 100644 index 00000000..3912f9d2 --- /dev/null +++ b/src/main/java/ee/carlrobert/codegpt/completions/llama/LlamaServerMessage.java @@ -0,0 +1,26 @@ +package ee.carlrobert.codegpt.completions.llama; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; + +@JsonIgnoreProperties(ignoreUnknown = true) +public class LlamaServerMessage { + + private final String level; + private final String message; + + public LlamaServerMessage( + @JsonProperty("level") String level, + @JsonProperty("message") String message) { + this.level = level; + this.message = message; + } + + public String getLevel() { + return level; + } + + public String getMessage() { + return message; + } +} diff --git a/src/main/java/ee/carlrobert/codegpt/completions/llama/PromptTemplate.java b/src/main/java/ee/carlrobert/codegpt/completions/llama/PromptTemplate.java new file mode 100644 index 00000000..74e71b22 --- /dev/null +++ b/src/main/java/ee/carlrobert/codegpt/completions/llama/PromptTemplate.java @@ -0,0 +1,113 @@ +package ee.carlrobert.codegpt.completions.llama; + +import ee.carlrobert.codegpt.conversations.message.Message; +import java.util.List; + +public enum PromptTemplate { + + CHAT_ML("Chat Markup Language (ChatML)") { + @Override + public String buildPrompt(String systemPrompt, String userPrompt, List history) { + StringBuilder prompt = new StringBuilder(); + + if (systemPrompt != null && !systemPrompt.isEmpty()) { + prompt.append("<|im_start|>system\n") + .append(systemPrompt) + .append("<|im_end|>\n"); + } + + for (Message message : history) { + prompt.append("<|im_start|>user\n") + .append(message.getPrompt()) + .append("<|im_end|>\n") + .append("<|im_start|>assistant\n") + .append(message.getResponse()) + .append("<|im_end|>\n"); + } + + return prompt.append("<|im_start|>user\n") + .append(userPrompt) + .append("<|im_end|>") + .toString(); + } + }, + LLAMA("Llama") { + @Override + public String buildPrompt(String systemPrompt, String userPrompt, List history) { + StringBuilder prompt = new StringBuilder(); + + if (systemPrompt != null && !systemPrompt.isEmpty()) { + prompt.append("<>") + .append(systemPrompt) + .append("<>\n"); + } + + for (Message message : history) { + prompt.append("[INST]") + .append(message.getPrompt()) + .append("[/INST]\n") + .append(message.getResponse()).append("\n"); + } + + return prompt.append("[INST]") + .append(userPrompt) + .append("[/INST]") + .toString(); + } + }, + TORA("ToRA") { + @Override + public String buildPrompt(String systemPrompt, String userPrompt, List history) { + StringBuilder prompt = new StringBuilder(); + + for (Message message : history) { + prompt.append("<|user|>\n") + .append(message.getPrompt()) + .append("\n<|assistant|>\n") + .append(message.getResponse()).append("\n"); + } + + return prompt.append("<|user|>\n") + .append(userPrompt) + .append("\n<|assistant|>") + .toString(); + } + }, + ALPACA("Alpaca/Vicuna") { + @Override + public String buildPrompt(String systemPrompt, String userPrompt, List history) { + StringBuilder prompt = new StringBuilder(); + + prompt.append( + "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"); + + for (Message message : history) { + prompt.append("### Instruction\n") + .append(message.getPrompt()) + .append("\n\n") + .append("### Response:\n") + .append(message.getResponse()) + .append("\n\n"); + } + + return prompt.append("### Instruction\n") + .append(userPrompt) + .append("\n\n") + .append("### Response:\n") + .toString(); + } + }; + + private final String label; + + PromptTemplate(String label) { + this.label = label; + } + + public abstract String buildPrompt(String systemPrompt, String userPrompt, List history); + + @Override + public String toString() { + return label; + } +} diff --git a/src/main/java/ee/carlrobert/codegpt/completions/you/YouApiClient.java b/src/main/java/ee/carlrobert/codegpt/completions/you/YouApiClient.java index 4d4f1918..6e6b2379 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/you/YouApiClient.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/you/YouApiClient.java @@ -1,51 +1,56 @@ package ee.carlrobert.codegpt.completions.you; -import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.intellij.openapi.application.ApplicationManager; import com.intellij.openapi.components.Service; -import java.nio.charset.StandardCharsets; -import java.util.Base64; -import java.util.Map; -import okhttp3.Callback; +import ee.carlrobert.codegpt.completions.you.auth.response.YouAuthenticationResponse; +import java.io.IOException; +import java.util.List; import okhttp3.Headers; import okhttp3.OkHttpClient; import okhttp3.Request; -import okhttp3.RequestBody; +import org.jetbrains.annotations.Nullable; @Service public final class YouApiClient { - private static final String API_BASE_URL = "https://web.stytch.com/sdk"; - private static final String publicToken = "public-token-live-507a52ad-7e69-496b-aee0-1c9863c7c819"; + private static final String API_BASE_URL = "https://you.com/api"; public static YouApiClient getInstance() { return ApplicationManager.getApplication().getService(YouApiClient.class); } - public void authenticate(String email, String password, Callback callback) { - try { - new OkHttpClient() - .newCall(new Request.Builder() - .url(API_BASE_URL + "/v1/passwords/authenticate") - .headers(Headers.of( - "content-type", "application/json", - "authority", "web.stytch.com", - "authorization", "Basic " + Base64.getEncoder().encodeToString((publicToken + ":" + publicToken).getBytes()), - "x-sdk-client", "eyJldmVudF9pZCI6ImV2ZW50LWlkLWY5YmU4YWU5LWE3MjctNGFlYy1hNzY0LTk4NDg1NDFkZjcwYSIsImFwcF9zZXNzaW9uX2lkIjoiYXBwLXNlc3Npb24taWQtYjY1NzcwZjMtMWFkMy00YjlhLWFjYzctMzJjNWQyMGMxNGU0IiwicGVyc2lzdGVudF9pZCI6InBlcnNpc3RlbnQtaWQtYzY0M2M0YTMtZDg5MC00ZGJkLTk3YjQtMjY0MmFlODdkMTZhIiwiY2xpZW50X3NlbnRfYXQiOiIyMDIzLTA5LTAxVDIyOjMwOjU1LjIzNFoiLCJ0aW1lem9uZSI6IkV1cm9wZS9UYWxsaW5uIiwiYXBwIjp7ImlkZW50aWZpZXIiOiJ5b3UuY29tIn0sInNkayI6eyJpZGVudGlmaWVyIjoiU3R5dGNoLmpzIEphdmFzY3JpcHQgU0RLIC0gWU9VLkNPTSBERUJVRyBCVUlMRCIsInZlcnNpb24iOiI0LjAuMCJ9fQ==", - "x-sdk-parent-host", "https://you.com" - )) - .post(RequestBody.create(new ObjectMapper() - .writerWithDefaultPrettyPrinter() - .writeValueAsString(Map.of( - "email", email, - "password", password, - "session_duration_minutes", 129_600)) - .getBytes(StandardCharsets.UTF_8))) - .build()) - .enqueue(callback); - } catch (JsonProcessingException e) { - throw new RuntimeException("Could not process request", e); + public @Nullable YouSubscription getSubscription(YouAuthenticationResponse auth) { + var sessionId = auth.getData().getSession().getSessionId(); + var sessionJwt = auth.getData().getSessionJwt(); + var request = new Request.Builder() + .url(API_BASE_URL + "/payments/orders/subscriptions/current") + .header("Accept", "application/json") + .header("Cache-Control", "no-cache") + .header("User-Agent", "youide CodeGPT") + .header("Cookie", ( + "stytch_session=" + sessionId + "; " + + "ydc_stytch_session=" + sessionId + "; " + + "stytch_session_jwt=" + sessionJwt + "; " + + "ydc_stytch_session_jwt=" + sessionJwt + "; ")) + .get() + .build(); + + try (var response = new OkHttpClient().newCall(request).execute()) { + var body = response.body(); + if (body == null || !response.isSuccessful()) { + return null; + } + List subscriptions = + new ObjectMapper().readValue(body.string(), new TypeReference<>() { + }); + if (subscriptions == null || subscriptions.isEmpty()) { + return null; + } + return subscriptions.get(0); + } catch (IOException ex) { + throw new RuntimeException("Could not get You.com subscription", ex); } } } diff --git a/src/main/java/ee/carlrobert/codegpt/completions/you/YouSubscription.java b/src/main/java/ee/carlrobert/codegpt/completions/you/YouSubscription.java new file mode 100644 index 00000000..6a2ba6b9 --- /dev/null +++ b/src/main/java/ee/carlrobert/codegpt/completions/you/YouSubscription.java @@ -0,0 +1,34 @@ +package ee.carlrobert.codegpt.completions.you; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; + +@JsonIgnoreProperties(ignoreUnknown = true) +public class YouSubscription { + + private final String service; + private final String tier; + private final String month; + + public YouSubscription( + @JsonProperty("service") String service, + @JsonProperty("tier") String tier, + @JsonProperty("month") String month) { + this.service = service; + this.tier = tier; + this.month = month; + } + + public String getService() { + return service; + } + + public String getTier() { + return tier; + } + + public String getMonth() { + return month; + } +} + diff --git a/src/main/java/ee/carlrobert/codegpt/completions/you/YouSubscriptionNotifier.java b/src/main/java/ee/carlrobert/codegpt/completions/you/YouSubscriptionNotifier.java new file mode 100644 index 00000000..1fc16a50 --- /dev/null +++ b/src/main/java/ee/carlrobert/codegpt/completions/you/YouSubscriptionNotifier.java @@ -0,0 +1,11 @@ +package ee.carlrobert.codegpt.completions.you; + +import com.intellij.util.messages.Topic; + +public interface YouSubscriptionNotifier { + + Topic SUBSCRIPTION_TOPIC = + Topic.create("subscriptionTopic", YouSubscriptionNotifier.class); + + void subscribed(); +} \ No newline at end of file diff --git a/src/main/java/ee/carlrobert/codegpt/completions/you/YouUserManager.java b/src/main/java/ee/carlrobert/codegpt/completions/you/YouUserManager.java index ac010c2c..219c8511 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/you/YouUserManager.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/you/YouUserManager.java @@ -9,6 +9,7 @@ import ee.carlrobert.codegpt.completions.you.auth.response.YouAuthenticationResp public final class YouUserManager { private YouAuthenticationResponse authenticationResponse; + private boolean subscribed; private YouUserManager() { } @@ -27,14 +28,19 @@ public final class YouUserManager { public void clearSession() { authenticationResponse = null; + subscribed = false; ApplicationManager.getApplication().getMessageBus() .syncPublisher(SignedOutNotifier.SIGNED_OUT_TOPIC) .signedOut(); } + public void setSubscribed(boolean subscribed) { + this.subscribed = subscribed; + } + public boolean isSubscribed() { - return true; // TODO + return subscribed; } public boolean isAuthenticated() { diff --git a/src/main/java/ee/carlrobert/codegpt/completions/you/auth/YouAuthClient.java b/src/main/java/ee/carlrobert/codegpt/completions/you/auth/YouAuthClient.java new file mode 100644 index 00000000..b4ddcab4 --- /dev/null +++ b/src/main/java/ee/carlrobert/codegpt/completions/you/auth/YouAuthClient.java @@ -0,0 +1,51 @@ +package ee.carlrobert.codegpt.completions.you.auth; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.intellij.openapi.application.ApplicationManager; +import com.intellij.openapi.components.Service; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.Map; +import okhttp3.Callback; +import okhttp3.Headers; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; + +@Service +public final class YouAuthClient { + + private static final String API_BASE_URL = "https://web.stytch.com/sdk"; + private static final String publicToken = "public-token-live-507a52ad-7e69-496b-aee0-1c9863c7c819"; + + public static YouAuthClient getInstance() { + return ApplicationManager.getApplication().getService(YouAuthClient.class); + } + + public void authenticate(String email, String password, Callback callback) { + try { + new OkHttpClient() + .newCall(new Request.Builder() + .url(API_BASE_URL + "/v1/passwords/authenticate") + .headers(Headers.of( + "content-type", "application/json", + "authority", "web.stytch.com", + "authorization", "Basic " + Base64.getEncoder().encodeToString((publicToken + ":" + publicToken).getBytes()), + "x-sdk-client", "eyJldmVudF9pZCI6ImV2ZW50LWlkLWY5YmU4YWU5LWE3MjctNGFlYy1hNzY0LTk4NDg1NDFkZjcwYSIsImFwcF9zZXNzaW9uX2lkIjoiYXBwLXNlc3Npb24taWQtYjY1NzcwZjMtMWFkMy00YjlhLWFjYzctMzJjNWQyMGMxNGU0IiwicGVyc2lzdGVudF9pZCI6InBlcnNpc3RlbnQtaWQtYzY0M2M0YTMtZDg5MC00ZGJkLTk3YjQtMjY0MmFlODdkMTZhIiwiY2xpZW50X3NlbnRfYXQiOiIyMDIzLTA5LTAxVDIyOjMwOjU1LjIzNFoiLCJ0aW1lem9uZSI6IkV1cm9wZS9UYWxsaW5uIiwiYXBwIjp7ImlkZW50aWZpZXIiOiJ5b3UuY29tIn0sInNkayI6eyJpZGVudGlmaWVyIjoiU3R5dGNoLmpzIEphdmFzY3JpcHQgU0RLIC0gWU9VLkNPTSBERUJVRyBCVUlMRCIsInZlcnNpb24iOiI0LjAuMCJ9fQ==", + "x-sdk-parent-host", "https://you.com" + )) + .post(RequestBody.create(new ObjectMapper() + .writerWithDefaultPrettyPrinter() + .writeValueAsString(Map.of( + "email", email, + "password", password, + "session_duration_minutes", 129_600)) + .getBytes(StandardCharsets.UTF_8))) + .build()) + .enqueue(callback); + } catch (JsonProcessingException e) { + throw new RuntimeException("Could not process request", e); + } + } +} diff --git a/src/main/java/ee/carlrobert/codegpt/completions/you/auth/YouAuthenticationService.java b/src/main/java/ee/carlrobert/codegpt/completions/you/auth/YouAuthenticationService.java index 869f0313..687d9b17 100644 --- a/src/main/java/ee/carlrobert/codegpt/completions/you/auth/YouAuthenticationService.java +++ b/src/main/java/ee/carlrobert/codegpt/completions/you/auth/YouAuthenticationService.java @@ -6,6 +6,7 @@ import com.intellij.openapi.application.ApplicationManager; import com.intellij.openapi.components.Service; import com.intellij.openapi.diagnostic.Logger; import ee.carlrobert.codegpt.completions.you.YouApiClient; +import ee.carlrobert.codegpt.completions.you.YouSubscriptionNotifier; import ee.carlrobert.codegpt.completions.you.YouUserManager; import ee.carlrobert.codegpt.completions.you.auth.response.YouAuthenticationResponse; import ee.carlrobert.codegpt.util.OverlayUtils; @@ -19,7 +20,7 @@ import org.jetbrains.annotations.NotNull; public final class YouAuthenticationService { private static final Logger LOG = Logger.getInstance(YouAuthenticationService.class); - private static final YouApiClient client = YouApiClient.getInstance(); + private static final YouAuthClient authClient = YouAuthClient.getInstance(); private YouAuthenticationService() { } @@ -28,8 +29,9 @@ public final class YouAuthenticationService { return ApplicationManager.getApplication().getService(YouAuthenticationService.class); } - public void signInAsync(String email, String password, AuthenticationHandler authenticationHandler) { - client.authenticate(email, password, new AuthenticationCallback(authenticationHandler)); + public void signInAsync(String email, String password, + AuthenticationHandler authenticationHandler) { + authClient.authenticate(email, password, new AuthenticationCallback(authenticationHandler)); } static class AuthenticationCallback implements Callback { @@ -56,11 +58,23 @@ public final class YouAuthenticationService { if (response.code() == 200) { try { - var authenticationResponse = new ObjectMapper().readValue(body.string(), YouAuthenticationResponse.class); - YouUserManager.getInstance().setAuthenticationResponse(authenticationResponse); + var messageBus = ApplicationManager.getApplication().getMessageBus(); + var userManager = YouUserManager.getInstance(); + + var authenticationResponse = + new ObjectMapper().readValue(body.string(), YouAuthenticationResponse.class); + userManager.setAuthenticationResponse(authenticationResponse); authenticationHandler.handleAuthenticated(authenticationResponse); - ApplicationManager.getApplication().getMessageBus() + var subscription = + YouApiClient.getInstance().getSubscription(authenticationResponse); + var subscribed = subscription != null && "youpro".equals(subscription.getService()); + userManager.setSubscribed(subscribed); + if (subscribed) { + messageBus.syncPublisher(YouSubscriptionNotifier.SUBSCRIPTION_TOPIC).subscribed(); + } + + messageBus .syncPublisher(AuthenticationNotifier.AUTHENTICATION_TOPIC) .authenticationSuccessful(); return; @@ -70,7 +84,8 @@ public final class YouAuthenticationService { } try { - authenticationHandler.handleError(new ObjectMapper().readValue(body.string(), YouAuthenticationError.class)); + authenticationHandler.handleError( + new ObjectMapper().readValue(body.string(), YouAuthenticationError.class)); } catch (Throwable ex) { authenticationHandler.handleGenericError(); throw new RuntimeException(ex); diff --git a/src/main/java/ee/carlrobert/codegpt/conversations/ConversationService.java b/src/main/java/ee/carlrobert/codegpt/conversations/ConversationService.java index 7e4ecbaa..bac5ce76 100644 --- a/src/main/java/ee/carlrobert/codegpt/conversations/ConversationService.java +++ b/src/main/java/ee/carlrobert/codegpt/conversations/ConversationService.java @@ -6,6 +6,7 @@ import com.intellij.openapi.application.ApplicationManager; import com.intellij.openapi.components.Service; import ee.carlrobert.codegpt.conversations.message.Message; import ee.carlrobert.codegpt.settings.state.AzureSettingsState; +import ee.carlrobert.codegpt.settings.state.LlamaSettingsState; import ee.carlrobert.codegpt.settings.state.OpenAISettingsState; import ee.carlrobert.codegpt.settings.state.SettingsState; import java.time.LocalDateTime; @@ -46,8 +47,10 @@ public final class ConversationService { conversation.setModel("YouCode"); } else if (settings.isUseAzureService()) { conversation.setModel(AzureSettingsState.getInstance().getModel()); - } else { + } else if (settings.isUseOpenAIService()) { conversation.setModel(OpenAISettingsState.getInstance().getModel()); + } else { + conversation.setModel(LlamaSettingsState.getInstance().getHuggingFaceModel().getCode()); } conversation.setCreatedOn(LocalDateTime.now()); conversation.setUpdatedOn(LocalDateTime.now()); @@ -64,7 +67,11 @@ public final class ConversationService { conversationsMapping.put(conversation.getClientCode(), conversations); } - public void saveMessage(String response, Message message, Conversation conversation, boolean isRetry) { + public void saveMessage( + String response, + Message message, + Conversation conversation, + boolean isRetry) { var conversationMessages = conversation.getMessages(); if (isRetry && !conversationMessages.isEmpty()) { var messageToBeSaved = conversationMessages.stream() @@ -122,6 +129,9 @@ public final class ConversationService { if (settings.isUseAzureService()) { return "azure.chat.completion"; } + if (settings.isUseLlamaService()) { + return "llama.chat.completion"; + } return "you.chat.completion"; } diff --git a/src/main/java/ee/carlrobert/codegpt/indexes/FolderStructureTreePanel.java b/src/main/java/ee/carlrobert/codegpt/indexes/FolderStructureTreePanel.java index 918409a2..2147b563 100644 --- a/src/main/java/ee/carlrobert/codegpt/indexes/FolderStructureTreePanel.java +++ b/src/main/java/ee/carlrobert/codegpt/indexes/FolderStructureTreePanel.java @@ -19,8 +19,8 @@ import com.intellij.ui.ScrollPaneFactory; import com.intellij.ui.components.JBLabel; import com.intellij.util.ui.AsyncProcessIcon; import com.intellij.util.ui.JBUI; -import ee.carlrobert.embedding.CheckedFile; import ee.carlrobert.codegpt.util.file.FileUtils; +import ee.carlrobert.embedding.CheckedFile; import java.awt.BorderLayout; import java.awt.FlowLayout; import java.awt.event.MouseAdapter; @@ -123,7 +123,7 @@ public class FolderStructureTreePanel { panel.add(loadingFilesSpinner); } else { panel.add(new JBLabel("Total size: " + - convertFileSize(totalSize) + " ~ " + + FileUtils.convertFileSize(totalSize) + " ~ " + (convertLongValue(totalSize / 4)) + " tokens " + " ~ " + new DecimalFormat("#.##").format(((double) (totalSize / 4) / 1000) * 0.0001) + " $")); } @@ -137,7 +137,9 @@ public class FolderStructureTreePanel { } private List getCheckedVirtualFiles() { - return Arrays.stream(checkboxTree.getCheckedNodes(VirtualFileSystemEntry.class, node -> node instanceof VirtualFileImpl)) + return Arrays.stream(checkboxTree.getCheckedNodes( + VirtualFileSystemEntry.class, + node -> node instanceof VirtualFileImpl)) .map(entry -> (VirtualFileImpl) entry) .collect(toList()); } @@ -160,12 +162,15 @@ public class FolderStructureTreePanel { } } - private void traverseDirectory(@NotNull CheckedTreeNode parentNode, @NotNull VirtualFile projectDirectory) { + private void traverseDirectory(@NotNull CheckedTreeNode parentNode, + @NotNull VirtualFile projectDirectory) { for (VirtualFile childFile : projectDirectory.getChildren()) { var node = new CheckedTreeNode(childFile); parentNode.add(node); - if (!parentNode.isChecked() || ignoredFileDirectories.parallelStream().anyMatch(it -> it.equalsIgnoreCase(childFile.getName()))) { + var potentiallyIgnored = ignoredFileDirectories.parallelStream() + .anyMatch(it -> it.equalsIgnoreCase(childFile.getName())); + if (!parentNode.isChecked() || potentiallyIgnored) { node.setChecked(false); } @@ -180,7 +185,13 @@ public class FolderStructureTreePanel { private @NotNull CheckboxTree.CheckboxTreeCellRenderer createFileTypesRenderer() { return new CheckboxTree.CheckboxTreeCellRenderer() { @Override - public void customizeRenderer(JTree t, Object value, boolean selected, boolean expanded, boolean leaf, int row, boolean focus) { + public void customizeRenderer(JTree t, + Object value, + boolean selected, + boolean expanded, + boolean leaf, + int row, + boolean focus) { if (!(value instanceof CheckedTreeNode)) { return; } @@ -194,28 +205,18 @@ public class FolderStructureTreePanel { if (userObject instanceof VirtualDirectoryImpl) { getTextRenderer().setIcon(AllIcons.Nodes.Folder); } else { - var fileType = FileTypeManager.getInstance().getFileTypeByFile((VirtualFileSystemEntry) userObject); + var fileType = FileTypeManager.getInstance() + .getFileTypeByFile((VirtualFileSystemEntry) userObject); getTextRenderer().setIcon(fileType.getIcon()); - getTextRenderer().append(" - " + convertFileSize(((VirtualFileSystemEntry) userObject).getLength())); + getTextRenderer().append( + " - " + FileUtils.convertFileSize( + ((VirtualFileSystemEntry) userObject).getLength())); } } } }; } - private static String convertFileSize(long fileSizeInBytes) { - String[] units = {"B", "KB", "MB", "GB"}; - int unitIndex = 0; - double fileSize = fileSizeInBytes; - - while (fileSize >= 1024 && unitIndex < units.length - 1) { - fileSize /= 1024; - unitIndex++; - } - - return new DecimalFormat("#.##").format(fileSize) + " " + units[unitIndex]; - } - private static String convertLongValue(long value) { if (value >= 1_000_000) { return value / 1_000_000 + "M"; diff --git a/src/main/java/ee/carlrobert/codegpt/settings/ModelComboBox.java b/src/main/java/ee/carlrobert/codegpt/settings/ModelComboBox.java deleted file mode 100644 index 7c1cf2eb..00000000 --- a/src/main/java/ee/carlrobert/codegpt/settings/ModelComboBox.java +++ /dev/null @@ -1,30 +0,0 @@ -package ee.carlrobert.codegpt.settings; - -import com.intellij.openapi.ui.ComboBox; -import ee.carlrobert.llm.completion.CompletionModel; -import java.awt.Component; -import javax.swing.JList; -import javax.swing.plaf.basic.BasicComboBoxRenderer; - -public class ModelComboBox extends ComboBox { - - public ModelComboBox(CompletionModel[] options, CompletionModel selectedModel) { - super(options); - setSelectedItem(selectedModel); - setRenderer(getBasicComboBoxRenderer()); - } - - private BasicComboBoxRenderer getBasicComboBoxRenderer() { - return new BasicComboBoxRenderer() { - public Component getListCellRendererComponent(JList list, Object value, int index, boolean isSelected, boolean cellHasFocus) { - super.getListCellRendererComponent(list, value, index, isSelected, cellHasFocus); - - if (value != null) { - CompletionModel model = (CompletionModel) value; - setText(model.getDescription()); - } - return this; - } - }; - } -} diff --git a/src/main/java/ee/carlrobert/codegpt/settings/ServiceChangeNotifier.java b/src/main/java/ee/carlrobert/codegpt/settings/ServiceChangeNotifier.java deleted file mode 100644 index fd729cec..00000000 --- a/src/main/java/ee/carlrobert/codegpt/settings/ServiceChangeNotifier.java +++ /dev/null @@ -1,2 +0,0 @@ -package ee.carlrobert.codegpt.settings;public class ServiceChangeNotifier { -} diff --git a/src/main/java/ee/carlrobert/codegpt/settings/SettingsComponent.java b/src/main/java/ee/carlrobert/codegpt/settings/SettingsComponent.java index 4db65438..b016af88 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/SettingsComponent.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/SettingsComponent.java @@ -1,12 +1,22 @@ package ee.carlrobert.codegpt.settings; +import static java.util.stream.Collectors.toList; + import com.intellij.openapi.Disposable; -import com.intellij.ui.TitledSeparator; +import com.intellij.openapi.ui.ComboBox; +import com.intellij.openapi.ui.ComponentValidator; +import com.intellij.openapi.ui.ValidationInfo; +import com.intellij.openapi.util.SystemInfoRt; import com.intellij.ui.components.JBTextField; import com.intellij.util.ui.FormBuilder; -import com.intellij.util.ui.UI; import ee.carlrobert.codegpt.CodeGPTBundle; +import ee.carlrobert.codegpt.settings.service.ServiceSelectionForm; +import ee.carlrobert.codegpt.settings.service.ServiceType; +import ee.carlrobert.codegpt.settings.state.OpenAISettingsState; import ee.carlrobert.codegpt.settings.state.SettingsState; +import java.awt.CardLayout; +import java.util.Arrays; +import javax.swing.DefaultComboBoxModel; import javax.swing.JComponent; import javax.swing.JPanel; @@ -14,25 +24,53 @@ public class SettingsComponent { private final JPanel mainPanel; private final JBTextField displayNameField; + private final ComboBox serviceComboBox; private final ServiceSelectionForm serviceSelectionForm; - private final YouServiceSelectionPanel youServiceSelectionPanel; public SettingsComponent(Disposable parentDisposable, SettingsState settings) { - serviceSelectionForm = new ServiceSelectionForm(parentDisposable, settings); displayNameField = new JBTextField(settings.getDisplayName(), 20); - youServiceSelectionPanel = new YouServiceSelectionPanel(parentDisposable); + + serviceSelectionForm = new ServiceSelectionForm(parentDisposable); + var cardLayout = new CardLayout(); + var cards = new JPanel(cardLayout); + cards.add(serviceSelectionForm.getOpenAIServiceSectionPanel(), ServiceType.OPENAI.getCode()); + cards.add(serviceSelectionForm.getAzureServiceSectionPanel(), ServiceType.AZURE.getCode()); + cards.add(serviceSelectionForm.getYouServiceSectionPanel(), ServiceType.YOU.getCode()); + cards.add(serviceSelectionForm.getLlamaServiceSectionPanel(), ServiceType.LLAMA_CPP.getCode()); + var serviceComboBoxModel = new DefaultComboBoxModel(); + serviceComboBoxModel.addAll(Arrays.stream(ServiceType.values()) + .filter(it -> !"LLAMA_CPP".equals(it.getCode()) || SystemInfoRt.isUnix) + .collect(toList())); + serviceComboBox = new ComboBox<>(serviceComboBoxModel); + serviceComboBox.setSelectedItem(ServiceType.OPENAI); + serviceComboBox.setPreferredSize(displayNameField.getPreferredSize()); + var serviceInputValidator = createInputValidator(parentDisposable, serviceComboBox); + serviceInputValidator.revalidate(); + serviceComboBox.addItemListener(e -> { + serviceInputValidator.revalidate(); + cardLayout.show(cards, ((ServiceType) e.getItem()).getCode()); + }); + mainPanel = FormBuilder.createFormBuilder() - .addComponent(UI.PanelFactory.panel(displayNameField) - .withLabel(CodeGPTBundle.get("settingsConfigurable.section.integration.displayNameFieldLabel")) - .resizeX(false) - .createPanel()) - .addComponent(new TitledSeparator(CodeGPTBundle.get("settingsConfigurable.section.service.title"))) - .addComponent(serviceSelectionForm.getForm()) - .addVerticalGap(8) + .addLabeledComponent( + CodeGPTBundle.get("settingsConfigurable.displayName.label"), + displayNameField) + .addLabeledComponent( + CodeGPTBundle.get("settingsConfigurable.service.label"), + serviceComboBox) + .addComponent(cards) .addComponentFillVertically(new JPanel(), 0) .getPanel(); } + public ServiceType getSelectedService() { + return serviceComboBox.getItem(); + } + + public void setSelectedService(ServiceType serviceType) { + serviceComboBox.setSelectedItem(serviceType); + } + public JPanel getPanel() { return mainPanel; } @@ -41,18 +79,6 @@ public class SettingsComponent { return displayNameField; } - public String getEmail() { - return youServiceSelectionPanel.getEmail(); - } - - public void setEmail(String email) { - youServiceSelectionPanel.setEmail(email); - } - - public String getPassword() { - return youServiceSelectionPanel.getPassword(); - } - public ServiceSelectionForm getServiceSelectionForm() { return serviceSelectionForm; } @@ -64,4 +90,27 @@ public class SettingsComponent { public void setDisplayName(String displayName) { displayNameField.setText(displayName); } + + private ComponentValidator createInputValidator( + Disposable parentDisposable, + JComponent component) { + var validator = new ComponentValidator(parentDisposable) + .withValidator(() -> { + if (component instanceof ComboBox) { + var selectedItem = ((ComboBox) component).getSelectedItem(); + if (selectedItem == ServiceType.OPENAI && + OpenAISettingsState.getInstance().isOpenAIQuotaExceeded()) { + return new ValidationInfo( + CodeGPTBundle.get("settings.openaiQuotaExceeded"), + component); + } + } + + return null; + }) + .andStartOnFocusLost() + .installOn(component); + validator.enableValidation(); + return validator; + } } diff --git a/src/main/java/ee/carlrobert/codegpt/settings/SettingsConfigurable.java b/src/main/java/ee/carlrobert/codegpt/settings/SettingsConfigurable.java index 4fce4b7a..641c1951 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/SettingsConfigurable.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/SettingsConfigurable.java @@ -1,5 +1,10 @@ package ee.carlrobert.codegpt.settings; +import static ee.carlrobert.codegpt.settings.service.ServiceType.AZURE; +import static ee.carlrobert.codegpt.settings.service.ServiceType.LLAMA_CPP; +import static ee.carlrobert.codegpt.settings.service.ServiceType.OPENAI; +import static ee.carlrobert.codegpt.settings.service.ServiceType.YOU; + import com.intellij.openapi.Disposable; import com.intellij.openapi.options.Configurable; import com.intellij.openapi.util.Disposer; @@ -7,7 +12,9 @@ import ee.carlrobert.codegpt.CodeGPTBundle; import ee.carlrobert.codegpt.conversations.ConversationsState; import ee.carlrobert.codegpt.credentials.AzureCredentialsManager; import ee.carlrobert.codegpt.credentials.OpenAICredentialsManager; +import ee.carlrobert.codegpt.settings.service.ServiceType; import ee.carlrobert.codegpt.settings.state.AzureSettingsState; +import ee.carlrobert.codegpt.settings.state.LlamaSettingsState; import ee.carlrobert.codegpt.settings.state.OpenAISettingsState; import ee.carlrobert.codegpt.settings.state.SettingsState; import ee.carlrobert.codegpt.settings.state.YouSettingsState; @@ -49,14 +56,24 @@ public class SettingsConfigurable implements Configurable { var settings = SettingsState.getInstance(); var openAISettings = OpenAISettingsState.getInstance(); var azureSettings = AzureSettingsState.getInstance(); + var llamaSettings = LlamaSettingsState.getInstance(); var serviceSelectionForm = settingsComponent.getServiceSelectionForm(); + var llamaModelPreferencesForm = serviceSelectionForm.getLlamaModelPreferencesForm(); return !settingsComponent.getDisplayName().equals(settings.getDisplayName()) || - isServiceChanged(serviceSelectionForm, settings) || + isServiceChanged(settings) || openAISettings.isModified(serviceSelectionForm) || azureSettings.isModified(serviceSelectionForm) || serviceSelectionForm.isDisplayWebSearchResults() != - YouSettingsState.getInstance().isDisplayWebSearchResults(); + YouSettingsState.getInstance().isDisplayWebSearchResults() || + + llamaSettings.isUseCustomModel() != llamaModelPreferencesForm.isUseCustomLlamaModel() || + llamaSettings.getServerPort() != serviceSelectionForm.getLlamaServerPort() || + llamaSettings.getContextSize() != serviceSelectionForm.getContextSize() || + llamaSettings.getHuggingFaceModel() != llamaModelPreferencesForm.getSelectedModel() || + !llamaSettings.getPromptTemplate().equals(llamaModelPreferencesForm.getPromptTemplate()) || + !llamaSettings.getCustomLlamaModelPath() + .equals(llamaModelPreferencesForm.getCustomLlamaModelPath()); } @Override @@ -65,7 +82,8 @@ public class SettingsConfigurable implements Configurable { var settings = SettingsState.getInstance(); var openAISettings = OpenAISettingsState.getInstance(); var azureSettings = AzureSettingsState.getInstance(); - var serviceChanged = isServiceChanged(serviceSelectionForm, settings); + var llamaSettings = LlamaSettingsState.getInstance(); + var serviceChanged = isServiceChanged(settings); var modelChanged = openAISettings.getModel().equals(serviceSelectionForm.getOpenAIModel()) || azureSettings.getModel().equals(serviceSelectionForm.getAzureModel()); @@ -80,11 +98,21 @@ public class SettingsConfigurable implements Configurable { .setAzureActiveDirectoryToken(serviceSelectionForm.getAzureActiveDirectoryToken()); settings.setDisplayName(settingsComponent.getDisplayName()); - settings.setUseOpenAIService(serviceSelectionForm.isOpenAIServiceSelected()); - settings.setUseAzureService(serviceSelectionForm.isAzureServiceSelected()); - settings.setUseYouService(serviceSelectionForm.isYouServiceSelected()); + // TODO: Store as single enum value + settings.setUseOpenAIService(settingsComponent.getSelectedService() == OPENAI); + settings.setUseAzureService(settingsComponent.getSelectedService() == ServiceType.AZURE); + settings.setUseYouService(settingsComponent.getSelectedService() == ServiceType.YOU); YouSettingsState.getInstance() .setDisplayWebSearchResults(serviceSelectionForm.isDisplayWebSearchResults()); + settings.setUseLlamaService(settingsComponent.getSelectedService() == ServiceType.LLAMA_CPP); + + var llamaModelPreferencesForm = serviceSelectionForm.getLlamaModelPreferencesForm(); + llamaSettings.setCustomLlamaModelPath(llamaModelPreferencesForm.getCustomLlamaModelPath()); + llamaSettings.setHuggingFaceModel(llamaModelPreferencesForm.getSelectedModel()); + llamaSettings.setUseCustomModel(llamaModelPreferencesForm.isUseCustomLlamaModel()); + llamaSettings.setPromptTemplate(llamaModelPreferencesForm.getPromptTemplate()); + llamaSettings.setServerPort(serviceSelectionForm.getLlamaServerPort()); + llamaSettings.setContextSize(serviceSelectionForm.getContextSize()); openAISettings.apply(serviceSelectionForm); azureSettings.apply(serviceSelectionForm); @@ -93,7 +121,7 @@ public class SettingsConfigurable implements Configurable { resetActiveTab(); if (serviceChanged) { TelemetryAction.SETTINGS_CHANGED.createActionMessage() - .property("service", getServiceCode(serviceSelectionForm)) + .property("service", getServiceCode()) .send(); } } @@ -104,15 +132,32 @@ public class SettingsConfigurable implements Configurable { var settings = SettingsState.getInstance(); var openAISettings = OpenAISettingsState.getInstance(); var azureSettings = AzureSettingsState.getInstance(); + var llamaSettings = LlamaSettingsState.getInstance(); var serviceSelectionForm = settingsComponent.getServiceSelectionForm(); - settingsComponent.setEmail(settings.getEmail()); + // settingsComponent.setEmail(settings.getEmail()); settingsComponent.setDisplayName(settings.getDisplayName()); - serviceSelectionForm.setOpenAIServiceSelected(settings.isUseOpenAIService()); - serviceSelectionForm.setAzureServiceSelected(settings.isUseAzureService()); - serviceSelectionForm.setYouServiceSelected(settings.isUseYouService()); - + // TODO + if (settings.isUseOpenAIService()) { + settingsComponent.setSelectedService(OPENAI); + } + if (settings.isUseAzureService()) { + settingsComponent.setSelectedService(ServiceType.AZURE); + } + if (settings.isUseYouService()) { + settingsComponent.setSelectedService(ServiceType.YOU); + } + if (settings.isUseLlamaService()) { + settingsComponent.setSelectedService(ServiceType.LLAMA_CPP); + } + var llamaModelPreferencesForm = serviceSelectionForm.getLlamaModelPreferencesForm(); + llamaModelPreferencesForm.setSelectedModel(llamaSettings.getHuggingFaceModel()); + llamaModelPreferencesForm.setCustomLlamaModelPath(llamaSettings.getCustomLlamaModelPath()); + llamaModelPreferencesForm.setUseCustomLlamaModel(llamaSettings.isUseCustomModel()); + llamaModelPreferencesForm.setPromptTemplate(llamaSettings.getPromptTemplate()); + serviceSelectionForm.setLlamaServerPort(llamaSettings.getServerPort()); + serviceSelectionForm.setContextSize(llamaSettings.getContextSize()); openAISettings.reset(serviceSelectionForm); azureSettings.reset(serviceSelectionForm); @@ -128,12 +173,11 @@ public class SettingsConfigurable implements Configurable { settingsComponent = null; } - private boolean isServiceChanged( - ServiceSelectionForm serviceSelectionForm, - SettingsState settings) { - return serviceSelectionForm.isOpenAIServiceSelected() != settings.isUseOpenAIService() || - serviceSelectionForm.isAzureServiceSelected() != settings.isUseAzureService() || - serviceSelectionForm.isYouServiceSelected() != settings.isUseYouService(); + private boolean isServiceChanged(SettingsState settings) { + return (settingsComponent.getSelectedService() == OPENAI) != settings.isUseOpenAIService() || + (settingsComponent.getSelectedService() == AZURE) != settings.isUseAzureService() || + (settingsComponent.getSelectedService() == YOU) != settings.isUseYouService() || + (settingsComponent.getSelectedService() == LLAMA_CPP) != settings.isUseLlamaService(); } private void resetActiveTab() { @@ -146,16 +190,19 @@ public class SettingsConfigurable implements Configurable { project.getService(StandardChatToolWindowContentManager.class).resetActiveTab(); } - private String getServiceCode(ServiceSelectionForm serviceSelectionForm) { - if (serviceSelectionForm.isOpenAIServiceSelected()) { + private String getServiceCode() { + if (settingsComponent.getSelectedService() == OPENAI) { return "openai"; } - if (serviceSelectionForm.isAzureServiceSelected()) { + if (settingsComponent.getSelectedService() == AZURE) { return "azure"; } - if (serviceSelectionForm.isYouServiceSelected()) { + if (settingsComponent.getSelectedService() == YOU) { return "you"; } + if (settingsComponent.getSelectedService() == LLAMA_CPP) { + return "llama.cpp"; + } return null; } } diff --git a/src/main/java/ee/carlrobert/codegpt/settings/advanced/AdvancedSettingsComponent.java b/src/main/java/ee/carlrobert/codegpt/settings/advanced/AdvancedSettingsComponent.java index 92151c05..51d28fab 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/advanced/AdvancedSettingsComponent.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/advanced/AdvancedSettingsComponent.java @@ -38,7 +38,8 @@ public class AdvancedSettingsComponent { proxyTypeComboBox.setSelectedItem(advancedSettings.getProxyType()); proxyHostField = new JBTextField(advancedSettings.getProxyHost(), 20); proxyPortField = new PortField(); - proxyAuthCheckbox = new JBCheckBox(CodeGPTBundle.get("advancedSettingsConfigurable.section.proxy.authCheckBoxField.label")); + proxyAuthCheckbox = new JBCheckBox(CodeGPTBundle.get( + "advancedSettingsConfigurable.proxy.authCheckBoxField.label")); proxyAuthUsername = new JBTextField(20); proxyAuthUsername.setEnabled(advancedSettings.isProxyAuthSelected()); proxyAuthPassword = new JBPasswordField(); @@ -52,10 +53,11 @@ public class AdvancedSettingsComponent { readTimeoutField = new PortField(advancedSettings.getReadTimeout()); mainPanel = FormBuilder.createFormBuilder() - .addComponent(new TitledSeparator(CodeGPTBundle.get("advancedSettingsConfigurable.section.proxy.title"))) + .addComponent(new TitledSeparator(CodeGPTBundle.get( + "advancedSettingsConfigurable.proxy.title"))) .addComponent(createProxySettingsForm()) .addVerticalGap(4) - .addComponent(new TitledSeparator("Connection Settings")) + .addComponent(new TitledSeparator(CodeGPTBundle.get("advancedSettingsConfigurable.connectionSettings.title"))) .addComponent(createConnectionSettingsForm()) .addComponentFillVertically(new JPanel(), 0) .getPanel(); @@ -63,8 +65,8 @@ public class AdvancedSettingsComponent { private JPanel createConnectionSettingsForm() { var panel = FormBuilder.createFormBuilder() - .addLabeledComponent("Connection timeout (s):", connectionTimeoutField) - .addLabeledComponent("Read timeout (s):", readTimeoutField) + .addLabeledComponent(CodeGPTBundle.get("advancedSettingsConfigurable.connectionSettings.connectionTimeout.label"), connectionTimeoutField) + .addLabeledComponent(CodeGPTBundle.get("advancedSettingsConfigurable.connectionSettings.readTimeout.label"), readTimeoutField) .getPanel(); panel.setBorder(JBUI.Borders.emptyLeft(16)); return panel; @@ -145,15 +147,15 @@ public class AdvancedSettingsComponent { var proxyTypePanel = SwingUtils.createPanel( proxyTypeComboBox, - CodeGPTBundle.get("advancedSettingsConfigurable.section.proxy.typeComboBoxField.label"), + CodeGPTBundle.get("advancedSettingsConfigurable.proxy.typeComboBoxField.label"), false); var proxyHostPanel = SwingUtils.createPanel( proxyHostField, - CodeGPTBundle.get("advancedSettingsConfigurable.section.proxy.hostField.label"), + CodeGPTBundle.get("advancedSettingsConfigurable.proxy.hostField.label"), false); var proxyPortPanel = SwingUtils.createPanel( proxyPortField, - CodeGPTBundle.get("advancedSettingsConfigurable.section.proxy.portField.label"), + CodeGPTBundle.get("advancedSettingsConfigurable.proxy.portField.label"), false); SwingUtils.setEqualLabelWidths(proxyTypePanel, proxyHostPanel); SwingUtils.setEqualLabelWidths(proxyPortPanel, proxyHostPanel); @@ -166,10 +168,10 @@ public class AdvancedSettingsComponent { .createPanel()); var proxyUsernamePanel = SwingUtils.createPanel(proxyAuthUsername, - CodeGPTBundle.get("advancedSettingsConfigurable.section.proxy.usernameField.label"), + CodeGPTBundle.get("advancedSettingsConfigurable.proxy.usernameField.label"), false); var proxyPasswordPanel = SwingUtils.createPanel(proxyAuthPassword, - CodeGPTBundle.get("advancedSettingsConfigurable.section.proxy.passwordField.label"), + CodeGPTBundle.get("advancedSettingsConfigurable.proxy.passwordField.label"), false); SwingUtils.setEqualLabelWidths(proxyPasswordPanel, proxyUsernamePanel); diff --git a/src/main/java/ee/carlrobert/codegpt/settings/configuration/ConfigurationComponent.java b/src/main/java/ee/carlrobert/codegpt/settings/configuration/ConfigurationComponent.java index 46e97e80..02e2b1d0 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/configuration/ConfigurationComponent.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/configuration/ConfigurationComponent.java @@ -180,10 +180,10 @@ public class ConfigurationComponent { try { var value = Double.parseDouble(valueText); if (value > 1.0 || value < 0.0) { - return new ValidationInfo("Value must be between 0 and 1.", component); + return new ValidationInfo(CodeGPTBundle.get("validation.error.mustBeBetweenZeroAndOne"), component); } } catch (NumberFormatException e) { - return new ValidationInfo("Value must be number.", component); + return new ValidationInfo(CodeGPTBundle.get("validation.error.mustBeNumber"), component); } return null; diff --git a/src/main/java/ee/carlrobert/codegpt/settings/service/DownloadModelAction.java b/src/main/java/ee/carlrobert/codegpt/settings/service/DownloadModelAction.java new file mode 100644 index 00000000..dd3a5ca0 --- /dev/null +++ b/src/main/java/ee/carlrobert/codegpt/settings/service/DownloadModelAction.java @@ -0,0 +1,98 @@ +package ee.carlrobert.codegpt.settings.service; + +import static java.lang.String.format; + +import com.intellij.openapi.actionSystem.AnAction; +import com.intellij.openapi.actionSystem.AnActionEvent; +import com.intellij.openapi.diagnostic.Logger; +import com.intellij.openapi.progress.ProgressIndicator; +import com.intellij.openapi.progress.ProgressManager; +import com.intellij.openapi.progress.Task; +import com.intellij.openapi.project.Project; +import ee.carlrobert.codegpt.CodeGPTBundle; +import ee.carlrobert.codegpt.completions.HuggingFaceModel; +import ee.carlrobert.codegpt.util.DownloadingUtils; +import ee.carlrobert.codegpt.util.file.FileUtils; +import java.io.IOException; +import java.net.URL; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; +import javax.swing.DefaultComboBoxModel; +import org.jetbrains.annotations.NotNull; + +public class DownloadModelAction extends AnAction { + + private static final Logger LOG = Logger.getInstance(DownloadModelAction.class); + + private final Consumer onDownload; + private final Runnable onDownloaded; + private final Consumer onFailed; + private final Consumer onUpdateProgress; + private final DefaultComboBoxModel comboBoxModel; + + public DownloadModelAction( + Consumer onDownload, + Runnable onDownloaded, + Consumer onFailed, + Consumer onUpdateProgress, + DefaultComboBoxModel comboBoxModel) { + this.onDownload = onDownload; + this.onDownloaded = onDownloaded; + this.onFailed = onFailed; + this.onUpdateProgress = onUpdateProgress; + this.comboBoxModel = comboBoxModel; + } + + @Override + public void actionPerformed(@NotNull AnActionEvent e) { + ProgressManager.getInstance().run(new DownloadBackgroundTask(e.getProject())); + } + + class DownloadBackgroundTask extends Task.Backgroundable { + + DownloadBackgroundTask(Project project) { + super(project, CodeGPTBundle.get("settingsConfigurable.service.llama.progress.downloadingModel.title"), true); + } + + @Override + public void run(@NotNull ProgressIndicator indicator) { + var model = (HuggingFaceModel) comboBoxModel.getSelectedItem(); + URL url = model.getFileURL(); + ScheduledExecutorService executorService = Executors.newSingleThreadScheduledExecutor(); + ScheduledFuture progressUpdateScheduler = null; + + try { + onDownload.accept(indicator); + + indicator.setIndeterminate(false); + indicator.setText(format(CodeGPTBundle.get("settingsConfigurable.service.llama.progress.downloadingModelIndicator.text"), model.getFileName())); + + long fileSize = url.openConnection().getContentLengthLong(); + long[] bytesRead = {0}; + long startTime = System.currentTimeMillis(); + + progressUpdateScheduler = executorService.scheduleAtFixedRate(() -> + onUpdateProgress.accept( + DownloadingUtils.getFormattedDownloadProgress(startTime, fileSize, bytesRead[0])), + 0, 1, TimeUnit.SECONDS); + FileUtils.copyFileWithProgress(model.getFileName(), url, bytesRead, fileSize, indicator); + } catch (IOException ex) { + LOG.error("Unable to open connection", ex); + onFailed.accept(ex); + } finally { + if (progressUpdateScheduler != null) { + progressUpdateScheduler.cancel(true); + } + executorService.shutdown(); + } + } + + @Override + public void onSuccess() { + onDownloaded.run(); + } + } +} \ No newline at end of file diff --git a/src/main/java/ee/carlrobert/codegpt/settings/service/LlamaModelPreferencesForm.java b/src/main/java/ee/carlrobert/codegpt/settings/service/LlamaModelPreferencesForm.java new file mode 100644 index 00000000..2d8fde21 --- /dev/null +++ b/src/main/java/ee/carlrobert/codegpt/settings/service/LlamaModelPreferencesForm.java @@ -0,0 +1,442 @@ +package ee.carlrobert.codegpt.settings.service; + +import static java.util.stream.Collectors.toList; + +import com.intellij.icons.AllIcons.Actions; +import com.intellij.icons.AllIcons.General; +import com.intellij.ide.HelpTooltip; +import com.intellij.openapi.actionSystem.AnAction; +import com.intellij.openapi.actionSystem.AnActionEvent; +import com.intellij.openapi.application.ApplicationManager; +import com.intellij.openapi.fileChooser.FileChooserDescriptorFactory; +import com.intellij.openapi.progress.ProgressIndicator; +import com.intellij.openapi.ui.ComboBox; +import com.intellij.openapi.ui.TextBrowseFolderListener; +import com.intellij.openapi.ui.TextFieldWithBrowseButton; +import com.intellij.openapi.ui.panel.ComponentPanelBuilder; +import com.intellij.openapi.util.io.FileUtil; +import com.intellij.ui.EnumComboBoxModel; +import com.intellij.ui.components.AnActionLink; +import com.intellij.ui.components.JBCheckBox; +import com.intellij.ui.components.JBLabel; +import com.intellij.util.ui.FormBuilder; +import com.intellij.util.ui.JBUI; +import ee.carlrobert.codegpt.CodeGPTBundle; +import ee.carlrobert.codegpt.CodeGPTPlugin; +import ee.carlrobert.codegpt.completions.HuggingFaceModel; +import ee.carlrobert.codegpt.completions.llama.LlamaModel; +import ee.carlrobert.codegpt.completions.llama.LlamaServerAgent; +import ee.carlrobert.codegpt.completions.llama.PromptTemplate; +import ee.carlrobert.codegpt.settings.state.LlamaSettingsState; +import java.awt.BorderLayout; +import java.awt.Dimension; +import java.awt.FlowLayout; +import java.io.File; +import java.util.Map; +import javax.swing.Box; +import javax.swing.DefaultComboBoxModel; +import javax.swing.JPanel; +import javax.swing.SwingUtilities; +import org.jetbrains.annotations.NotNull; + +public class LlamaModelPreferencesForm { + + private static final Map> modelDetailsMap = Map.of( + 7, Map.of( + 3, new ModelDetails(3.30, 5.80), + 4, new ModelDetails(4.08, 6.58), + 5, new ModelDetails(4.78, 7.28)), + 13, Map.of( + 3, new ModelDetails(6.34, 8.84), + 4, new ModelDetails(7.87, 10.37), + 5, new ModelDetails(9.23, 11.73)), + 34, Map.of( + 3, new ModelDetails(16.28, 18.78), + 4, new ModelDetails(20.22, 22.72), + 5, new ModelDetails(23.84, 26.34))); + + private final TextFieldWithBrowseButton customModelPathBrowserButton; + private final ComboBox modelComboBox; + private final ComboBox modelSizeComboBox; + private final ComboBox huggingFaceModelComboBox; + private final ComboBox promptTemplateComboBox; + private final JBLabel modelExistsIcon; + private final DefaultComboBoxModel huggingFaceComboBoxModel; + private final JBCheckBox useCustomModelCheckBox; + private final JBLabel helpIcon; + private final JPanel downloadModelActionLinkWrapper; + private final JBLabel progressLabel; + private final JBLabel modelDetailsLabel; + + public TextFieldWithBrowseButton getCustomModelPathBrowserButton() { + return customModelPathBrowserButton; + } + + public ComboBox getHuggingFaceModelComboBox() { + return huggingFaceModelComboBox; + } + + public LlamaModelPreferencesForm() { + var llamaServerAgent = ApplicationManager.getApplication().getService(LlamaServerAgent.class); + var llamaSettings = LlamaSettingsState.getInstance(); + customModelPathBrowserButton = createCustomModelPathBrowseButton( + llamaSettings.isUseCustomModel() && !llamaServerAgent.isServerRunning()); + customModelPathBrowserButton.setText(llamaSettings.getCustomLlamaModelPath()); + progressLabel = new JBLabel(""); + progressLabel.setBorder(JBUI.Borders.emptyLeft(2)); + progressLabel.setFont(JBUI.Fonts.smallFont()); + modelExistsIcon = new JBLabel(Actions.Commit); + modelExistsIcon.setVisible(isModelExists(llamaSettings.getHuggingFaceModel())); + helpIcon = new JBLabel(General.ContextHelp); + huggingFaceComboBoxModel = new DefaultComboBoxModel<>(); + var llm = llamaSettings.getHuggingFaceModel(); + var llamaModel = LlamaModel.findByHuggingFaceModel(llm); + + var selectableModels = llamaModel.getHuggingFaceModels().stream() + .filter(model -> model.getParameterSize() == llm.getParameterSize()) + .collect(toList()); + huggingFaceComboBoxModel.addAll(selectableModels); + huggingFaceComboBoxModel.setSelectedItem(selectableModels.get(0)); + downloadModelActionLinkWrapper = new JPanel(new BorderLayout()); + downloadModelActionLinkWrapper.setBorder(JBUI.Borders.emptyLeft(2)); + downloadModelActionLinkWrapper.add( + createDownloadModelLink( + progressLabel, + downloadModelActionLinkWrapper, + huggingFaceComboBoxModel), + BorderLayout.WEST); + modelDetailsLabel = new JBLabel(); + huggingFaceModelComboBox = createHuggingFaceComboBox( + huggingFaceComboBoxModel, + modelExistsIcon, + modelDetailsLabel, + downloadModelActionLinkWrapper); + huggingFaceModelComboBox.setEnabled(!llamaServerAgent.isServerRunning()); + var modelSizeComboBoxModel = new DefaultComboBoxModel(); + var initialModelSizes = llamaModel.getSortedUniqueModelSizes().stream() + .map(ModelSize::new) + .collect(toList()); + modelSizeComboBoxModel.addAll(initialModelSizes); + modelSizeComboBoxModel.setSelectedItem(initialModelSizes.get(0)); + var modelComboBoxModel = new EnumComboBoxModel<>(LlamaModel.class); + modelComboBox = createModelComboBox(modelComboBoxModel, llamaModel, modelSizeComboBoxModel); + modelComboBox.setEnabled(!llamaServerAgent.isServerRunning()); + modelSizeComboBox = createModelSizeComboBox( + modelComboBoxModel, + modelSizeComboBoxModel, + huggingFaceComboBoxModel); + modelSizeComboBox.setEnabled(initialModelSizes.size() > 1 && !llamaServerAgent.isServerRunning()); + promptTemplateComboBox = new ComboBox<>(new EnumComboBoxModel<>(PromptTemplate.class)); + promptTemplateComboBox.setSelectedItem(llamaSettings.getPromptTemplate()); + promptTemplateComboBox.setEnabled( + llamaSettings.isUseCustomModel() && !llamaServerAgent.isServerRunning()); + promptTemplateComboBox.setPreferredSize(modelComboBox.getPreferredSize()); + useCustomModelCheckBox = new JBCheckBox(CodeGPTBundle.get( + "settingsConfigurable.service.llama.useCustomModel.label"), llamaSettings.isUseCustomModel()); + useCustomModelCheckBox.setEnabled(!llamaServerAgent.isServerRunning()); + useCustomModelCheckBox.addChangeListener(e -> { + var selected = ((JBCheckBox) e.getSource()).isSelected(); + customModelPathBrowserButton.setEnabled(selected && !llamaServerAgent.isServerRunning()); + promptTemplateComboBox.setEnabled(selected && !llamaServerAgent.isServerRunning()); + modelComboBox.setEnabled(!selected); + modelSizeComboBox.setEnabled((!selected)); + huggingFaceModelComboBox.setEnabled((!selected)); + }); + } + + public JPanel getForm() { + var customModelHelpText = ComponentPanelBuilder.createCommentComponent( + CodeGPTBundle.get("settingsConfigurable.service.llama.customModelPath.comment"), + true); + customModelHelpText.setBorder(JBUI.Borders.empty(0, 4)); + var quantizationHelpText = ComponentPanelBuilder.createCommentComponent( + CodeGPTBundle.get("settingsConfigurable.service.llama.quantization.comment"), + true); + quantizationHelpText.setBorder(JBUI.Borders.empty(0, 4)); + + var modelComboBoxWrapper = new JPanel(new FlowLayout(FlowLayout.LEADING, 0, 0)); + modelComboBoxWrapper.add(modelComboBox); + modelComboBoxWrapper.add(Box.createHorizontalStrut(8)); + modelComboBoxWrapper.add(helpIcon); + modelComboBoxWrapper.add(Box.createHorizontalStrut(4)); + modelComboBoxWrapper.add(modelExistsIcon); + + var huggingFaceModelComboBoxWrapper = new JPanel(new FlowLayout(FlowLayout.LEADING, 0, 0)); + huggingFaceModelComboBoxWrapper.add(huggingFaceModelComboBox); + huggingFaceModelComboBoxWrapper.add(Box.createHorizontalStrut(8)); + huggingFaceModelComboBoxWrapper.add(modelDetailsLabel); + + return FormBuilder.createFormBuilder() + .addLabeledComponent(CodeGPTBundle.get("settingsConfigurable.shared.model.label"), modelComboBoxWrapper) + .addLabeledComponent(CodeGPTBundle.get("settingsConfigurable.service.llama.modelSize.label"), modelSizeComboBox) + .addLabeledComponent(CodeGPTBundle.get("settingsConfigurable.service.llama.quantization.label"), huggingFaceModelComboBoxWrapper) + .addComponentToRightColumn(quantizationHelpText) + .addComponentToRightColumn(downloadModelActionLinkWrapper) + .addComponentToRightColumn(progressLabel) + .addVerticalGap(8) + .addComponent(useCustomModelCheckBox) + .addLabeledComponent(CodeGPTBundle.get("settingsConfigurable.service.llama.promptTemplate.label"), promptTemplateComboBox) + .addLabeledComponent(CodeGPTBundle.get("settingsConfigurable.service.llama.customModelPath.label"), customModelPathBrowserButton) + .addComponentToRightColumn(customModelHelpText) + .addVerticalGap(4) + .getPanel(); + } + + public void enableFields(boolean enabled) { + modelComboBox.setEnabled(enabled); + modelSizeComboBox.setEnabled(enabled); + huggingFaceModelComboBox.setEnabled(enabled); + useCustomModelCheckBox.setEnabled(enabled); + promptTemplateComboBox.setEnabled(enabled && useCustomModelCheckBox.isSelected()); + customModelPathBrowserButton.setEnabled(enabled && useCustomModelCheckBox.isSelected()); + } + + private static class ModelDetails { + + double fileSize; + double maxRAMRequired; + + public ModelDetails(double fileSize, double maxRAMRequired) { + this.fileSize = fileSize; + this.maxRAMRequired = maxRAMRequired; + } + } + + private String getHuggingFaceModelDetailsHtml(HuggingFaceModel model) { + int parameterSize = model.getParameterSize(); + int quantization = model.getQuantization(); + + if (!modelDetailsMap.containsKey(parameterSize)) { + return ""; + } + + ModelDetails details = modelDetailsMap.get(parameterSize).get(quantization); + if (details == null) { + return ""; + } + + return String.format("" + + "

File Size: %.2f GB

" + + "

Max RAM Required: %.2f GB

" + + "", details.fileSize, details.maxRAMRequired); + } + + public void setSelectedModel(HuggingFaceModel model) { + huggingFaceComboBoxModel.setSelectedItem(model); + } + + public HuggingFaceModel getSelectedModel() { + return (HuggingFaceModel) huggingFaceComboBoxModel.getSelectedItem(); + } + + public void setCustomLlamaModelPath(String modelPath) { + customModelPathBrowserButton.setText(modelPath); + } + + public String getCustomLlamaModelPath() { + return customModelPathBrowserButton.getText(); + } + + public void setUseCustomLlamaModel(boolean useCustomLlamaModel) { + useCustomModelCheckBox.setSelected(useCustomLlamaModel); + } + + public boolean isUseCustomLlamaModel() { + return useCustomModelCheckBox.isSelected(); + } + + public void setPromptTemplate(PromptTemplate promptTemplate) { + promptTemplateComboBox.setSelectedItem(promptTemplate); + } + + public PromptTemplate getPromptTemplate() { + return promptTemplateComboBox.getItem(); + } + + private ComboBox createModelComboBox( + EnumComboBoxModel llamaModelEnumComboBoxModel, + LlamaModel llamaModel, + DefaultComboBoxModel modelSizeComboBoxModel) { + var comboBox = new ComboBox<>(llamaModelEnumComboBoxModel); + comboBox.setPreferredSize(new Dimension(280, comboBox.getPreferredSize().height)); + comboBox.setSelectedItem(llamaModel); + comboBox.addItemListener(e -> { + var selectedModel = (LlamaModel) e.getItem(); + var modelSizes = selectedModel.getSortedUniqueModelSizes().stream() + .map(ModelSize::new) + .collect(toList()); + + modelSizeComboBoxModel.removeAllElements(); + modelSizeComboBoxModel.addAll(modelSizes); + modelSizeComboBoxModel.setSelectedItem(modelSizes.get(0)); + modelSizeComboBox.setEnabled(modelSizes.size() > 1); + + var huggingFaceModels = selectedModel.getHuggingFaceModels().stream() + .filter(model -> { + var size = ((ModelSize) modelSizeComboBoxModel.getSelectedItem()).getSize(); + return size == model.getParameterSize(); + }) + .collect(toList()); + + huggingFaceComboBoxModel.removeAllElements(); + huggingFaceComboBoxModel.addAll(huggingFaceModels); + huggingFaceComboBoxModel.setSelectedItem(huggingFaceModels.get(0)); + }); + return comboBox; + } + + private ComboBox createModelSizeComboBox( + EnumComboBoxModel llamaModelComboBoxModel, + DefaultComboBoxModel modelSizeComboBoxModel, + DefaultComboBoxModel huggingFaceComboBoxModel) { + var comboBox = new ComboBox<>(modelSizeComboBoxModel); + comboBox.setPreferredSize(modelComboBox.getPreferredSize()); + comboBox.setSelectedItem(modelSizeComboBoxModel.getSelectedItem()); + comboBox.addItemListener(e -> { + var selectedModel = llamaModelComboBoxModel.getSelectedItem(); + var models = selectedModel.getHuggingFaceModels().stream() + .filter(model -> { + var selectedModelSize = (ModelSize) modelSizeComboBoxModel.getSelectedItem(); + return selectedModelSize != null && + selectedModelSize.getSize() == model.getParameterSize(); + }) + .collect(toList()); + if (!models.isEmpty()) { + huggingFaceComboBoxModel.removeAllElements(); + huggingFaceComboBoxModel.addAll(models); + huggingFaceComboBoxModel.setSelectedItem(models.get(0)); + } + }); + return comboBox; + } + + private ComboBox createHuggingFaceComboBox( + DefaultComboBoxModel huggingFaceComboBoxModel, + JBLabel modelExistsIcon, + JBLabel modelDetailsLabel, + JPanel downloadModelActionLinkWrapper) { + var comboBox = new ComboBox<>(huggingFaceComboBoxModel); + comboBox.addItemListener(e -> { + var selectedModel = (HuggingFaceModel) e.getItem(); + var modelExists = isModelExists(selectedModel); + + updateModelHelpTooltip(selectedModel); + modelDetailsLabel.setText(getHuggingFaceModelDetailsHtml(selectedModel)); + modelExistsIcon.setVisible(modelExists); + downloadModelActionLinkWrapper.setVisible(!modelExists); + }); + return comboBox; + } + + private TextFieldWithBrowseButton createCustomModelPathBrowseButton(boolean enabled) { + var browseButton = new TextFieldWithBrowseButton(); + browseButton.setEnabled(enabled); + + var fileChooserDescriptor = FileChooserDescriptorFactory.createSingleFileDescriptor("gguf"); + fileChooserDescriptor.setForcedToUseIdeaFileChooser(true); + fileChooserDescriptor.setHideIgnored(false); + browseButton.addBrowseFolderListener(new TextBrowseFolderListener(fileChooserDescriptor)); + return browseButton; + } + + private boolean isModelExists(HuggingFaceModel model) { + return FileUtil.exists( + CodeGPTPlugin.getLlamaModelsPath() + File.separator + model.getFileName()); + } + + private AnActionLink createCancelDownloadLink( + JBLabel progressLabel, + JPanel actionLinkWrapper, + DefaultComboBoxModel huggingFaceComboBoxModel, + ProgressIndicator progressIndicator) { + return new AnActionLink(CodeGPTBundle.get("settingsConfigurable.service.llama.cancelDownloadLink.label"), new AnAction() { + @Override + public void actionPerformed(@NotNull AnActionEvent e) { + SwingUtilities.invokeLater(() -> { + configureFieldsForDownloading(false); + updateActionLink( + actionLinkWrapper, + createDownloadModelLink(progressLabel, actionLinkWrapper, huggingFaceComboBoxModel)); + progressIndicator.cancel(); + }); + } + }); + } + + private void updateActionLink(JPanel actionLinkWrapper, AnActionLink actionLink) { + actionLinkWrapper.removeAll(); + actionLinkWrapper.add(actionLink, BorderLayout.WEST); + actionLinkWrapper.revalidate(); + actionLinkWrapper.repaint(); + } + + void configureFieldsForDownloading(boolean downloading) { + progressLabel.setText(""); + progressLabel.setVisible(downloading); + modelComboBox.setEnabled(!downloading); + modelSizeComboBox.setEnabled(!downloading); + huggingFaceModelComboBox.setEnabled(!downloading); + modelExistsIcon.setVisible(!downloading); + } + + private AnActionLink createDownloadModelLink( + JBLabel progressLabel, + JPanel actionLinkWrapper, + DefaultComboBoxModel huggingFaceComboBoxModel) { + return new AnActionLink(CodeGPTBundle.get("settingsConfigurable.service.llama.downloadModelLink.label"), new DownloadModelAction( + progressIndicator -> { + SwingUtilities.invokeLater(() -> { + configureFieldsForDownloading(true); + updateActionLink( + actionLinkWrapper, + createCancelDownloadLink( + progressLabel, + actionLinkWrapper, + huggingFaceComboBoxModel, + progressIndicator)); + }); + }, + () -> SwingUtilities.invokeLater(() -> { + configureFieldsForDownloading(false); + updateActionLink( + actionLinkWrapper, + createDownloadModelLink(progressLabel, actionLinkWrapper, huggingFaceComboBoxModel)); + actionLinkWrapper.setVisible(false); + LlamaSettingsState.getInstance() + .setHuggingFaceModel((HuggingFaceModel) huggingFaceComboBoxModel.getSelectedItem()); + }), + (error) -> { + throw new RuntimeException(error); + }, + (text) -> SwingUtilities.invokeLater(() -> progressLabel.setText(text)), + huggingFaceComboBoxModel), "unknown"); + } + + private void updateModelHelpTooltip(HuggingFaceModel model) { + helpIcon.setToolTipText(null); + var llamaModel = LlamaModel.findByHuggingFaceModel(model); + new HelpTooltip() + .setTitle(llamaModel.getLabel()) + .setDescription("

" + llamaModel.getDescription() + "

") + .setBrowserLink(CodeGPTBundle.get("settingsConfigurable.service.llama.linkToModel.label"), model.getHuggingFaceURL()) + .installOn(helpIcon); + } + + static class ModelSize { + + private final int size; + + ModelSize(int size) { + this.size = size; + } + + int getSize() { + return size; + } + + @Override + public String toString() { + return size + "B"; + } + } +} diff --git a/src/main/java/ee/carlrobert/codegpt/settings/service/LlamaServiceSelectionForm.java b/src/main/java/ee/carlrobert/codegpt/settings/service/LlamaServiceSelectionForm.java new file mode 100644 index 00000000..dfac9d85 --- /dev/null +++ b/src/main/java/ee/carlrobert/codegpt/settings/service/LlamaServiceSelectionForm.java @@ -0,0 +1,163 @@ +package ee.carlrobert.codegpt.settings.service; + +import com.intellij.icons.AllIcons.Actions; +import com.intellij.openapi.application.ApplicationManager; +import com.intellij.openapi.ui.MessageType; +import com.intellij.openapi.ui.panel.ComponentPanelBuilder; +import com.intellij.openapi.util.io.FileUtil; +import com.intellij.ui.PortField; +import com.intellij.ui.TitledSeparator; +import com.intellij.ui.components.JBLabel; +import com.intellij.ui.components.fields.IntegerField; +import com.intellij.util.ui.FormBuilder; +import com.intellij.util.ui.JBUI; +import ee.carlrobert.codegpt.CodeGPTBundle; +import ee.carlrobert.codegpt.CodeGPTPlugin; +import ee.carlrobert.codegpt.completions.HuggingFaceModel; +import ee.carlrobert.codegpt.completions.llama.LlamaServerAgent; +import ee.carlrobert.codegpt.settings.state.LlamaSettingsState; +import ee.carlrobert.codegpt.util.OverlayUtils; +import java.awt.BorderLayout; +import java.io.File; +import javax.swing.JButton; +import javax.swing.JComponent; +import javax.swing.JPanel; +import javax.swing.SwingConstants; + +public class LlamaServiceSelectionForm extends JPanel { + + private final LlamaModelPreferencesForm llamaModelPreferencesForm; + private final PortField portField; + private final IntegerField maxTokensField; + + public LlamaServiceSelectionForm() { + var llamaServerAgent = ApplicationManager.getApplication().getService(LlamaServerAgent.class); + var serverRunning = llamaServerAgent.isServerRunning(); + portField = new PortField(LlamaSettingsState.getInstance().getServerPort()); + portField.setEnabled(!serverRunning); + + var serverProgressPanel = new ServerProgressPanel(); + llamaModelPreferencesForm = new LlamaModelPreferencesForm(); + + maxTokensField = new IntegerField("max_tokens", 256, 4096); + maxTokensField.setColumns(12); + maxTokensField.setValue(2048); + maxTokensField.setEnabled(!serverRunning); + + var serverButton = new JButton(); + serverButton.setText(serverRunning ? + CodeGPTBundle.get("settingsConfigurable.service.llama.stopServer.label") : + CodeGPTBundle.get("settingsConfigurable.service.llama.startServer.label")); + serverButton.setIcon(serverRunning ? Actions.Suspend : Actions.Execute); + serverButton.addActionListener(event -> { + if (llamaModelPreferencesForm.isUseCustomLlamaModel()) { + var customModelPath = llamaModelPreferencesForm.getCustomLlamaModelPath(); + if (customModelPath == null || customModelPath.isEmpty()) { + OverlayUtils.showBalloon( + CodeGPTBundle.get("validation.error.fieldRequired"), + MessageType.ERROR, + llamaModelPreferencesForm.getCustomModelPathBrowserButton()); + return; + } + } else { + if (!isModelExists(llamaModelPreferencesForm.getSelectedModel())) { + OverlayUtils.showBalloon( + CodeGPTBundle.get("settingsConfigurable.service.llama.overlay.modelNotDownloaded.text"), + MessageType.ERROR, + llamaModelPreferencesForm.getHuggingFaceModelComboBox()); + return; + } + } + + if (llamaServerAgent.isServerRunning()) { + setFormEnabled(true); + serverButton.setText(CodeGPTBundle.get("settingsConfigurable.service.llama.startServer.label")); + serverButton.setIcon(Actions.Execute); + serverProgressPanel.updateText(CodeGPTBundle.get("settingsConfigurable.service.llama.progress.stoppingServer")); + llamaServerAgent.stopAgent(); + } else { + setFormEnabled(false); + serverButton.setText(CodeGPTBundle.get("settingsConfigurable.service.llama.stopServer.label")); + serverButton.setIcon(Actions.Suspend); + serverProgressPanel.startProgress(CodeGPTBundle.get("settingsConfigurable.service.llama.progress.startingServer")); + + // TODO: Move to LlamaModelPreferencesForm + var modelPath = llamaModelPreferencesForm.isUseCustomLlamaModel() ? + llamaModelPreferencesForm.getCustomLlamaModelPath() : + CodeGPTPlugin.getLlamaModelsPath() + + File.separator + + llamaModelPreferencesForm.getSelectedModel().getFileName(); + llamaServerAgent.startAgent( + modelPath, + maxTokensField.getValue(), + portField.getNumber(), + serverProgressPanel, + () -> { + setFormEnabled(false); + serverProgressPanel.displayComponent(new JBLabel( + "Server running", + Actions.Commit, + SwingConstants.LEADING)); + }); + } + }); + + var contextSizeHelpText = ComponentPanelBuilder.createCommentComponent( + CodeGPTBundle.get("settingsConfigurable.service.llama.contextSize.comment"), + true); + contextSizeHelpText.setBorder(JBUI.Borders.empty(0, 4)); + + setLayout(new BorderLayout()); + add(FormBuilder.createFormBuilder() + .addComponent(new TitledSeparator(CodeGPTBundle.get("settingsConfigurable.service.llama.modelPreferences.title"))) + .addComponent(withEmptyLeftBorder(llamaModelPreferencesForm.getForm())) + .addComponent(new TitledSeparator(CodeGPTBundle.get("settingsConfigurable.service.llama.serverPreferences.title"))) + .addComponent(withEmptyLeftBorder(FormBuilder.createFormBuilder() + .addLabeledComponent(CodeGPTBundle.get("settingsConfigurable.service.llama.contextSize.label"), maxTokensField) + .addComponentToRightColumn(contextSizeHelpText) + .addLabeledComponent(CodeGPTBundle.get("settingsConfigurable.service.llama.port.label"), JBUI.Panels.simplePanel() + .addToLeft(portField) + .addToRight(serverButton)) + .getPanel())) + .addVerticalGap(4) + .addComponent(withEmptyLeftBorder(serverProgressPanel)) + .addComponentFillVertically(new JPanel(), 0) + .getPanel()); + } + + private boolean isModelExists(HuggingFaceModel model) { + return FileUtil.exists( + CodeGPTPlugin.getLlamaModelsPath() + File.separator + model.getFileName()); + } + + private void setFormEnabled(boolean enabled) { + llamaModelPreferencesForm.enableFields(enabled); + portField.setEnabled(enabled); + maxTokensField.setEnabled(enabled); + } + + public void setServerPort(int serverPort) { + portField.setNumber(serverPort); + } + + public int getServerPort() { + return portField.getNumber(); + } + + public LlamaModelPreferencesForm getLlamaModelPreferencesForm() { + return llamaModelPreferencesForm; + } + + private JComponent withEmptyLeftBorder(JComponent component) { + component.setBorder(JBUI.Borders.emptyLeft(16)); + return component; + } + + public int getContextSize() { + return maxTokensField.getValue(); + } + + public void setContextSize(int contextSize) { + maxTokensField.setValue(contextSize); + } +} diff --git a/src/main/java/ee/carlrobert/codegpt/settings/service/ServerProgressPanel.java b/src/main/java/ee/carlrobert/codegpt/settings/service/ServerProgressPanel.java new file mode 100644 index 00000000..8e2ad0ad --- /dev/null +++ b/src/main/java/ee/carlrobert/codegpt/settings/service/ServerProgressPanel.java @@ -0,0 +1,37 @@ +package ee.carlrobert.codegpt.settings.service; + +import com.intellij.ui.components.JBLabel; +import com.intellij.util.ui.AsyncProcessIcon; +import java.awt.FlowLayout; +import javax.swing.Box; +import javax.swing.JComponent; +import javax.swing.JPanel; + +public class ServerProgressPanel extends JPanel { + + private final JBLabel label = new JBLabel(); + + public ServerProgressPanel() { + super(new FlowLayout(FlowLayout.LEADING, 0, 0)); + setVisible(false); + add(new AsyncProcessIcon("sign_in_spinner")); + add(Box.createHorizontalStrut(4)); + add(label); + } + + public void startProgress(String text) { + setVisible(true); + updateText(text); + } + + public void updateText(String text) { + label.setText(text); + } + + public void displayComponent(JComponent component) { + removeAll(); + add(component); + revalidate(); + repaint(); + } +} diff --git a/src/main/java/ee/carlrobert/codegpt/settings/ServiceSelectionForm.java b/src/main/java/ee/carlrobert/codegpt/settings/service/ServiceSelectionForm.java similarity index 65% rename from src/main/java/ee/carlrobert/codegpt/settings/ServiceSelectionForm.java rename to src/main/java/ee/carlrobert/codegpt/settings/service/ServiceSelectionForm.java index 057bc7cf..75c5029c 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/ServiceSelectionForm.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/service/ServiceSelectionForm.java @@ -1,58 +1,46 @@ -package ee.carlrobert.codegpt.settings; +package ee.carlrobert.codegpt.settings.service; import com.intellij.openapi.Disposable; import com.intellij.openapi.application.ApplicationManager; -import com.intellij.openapi.editor.colors.EditorColorsManager; +import com.intellij.openapi.diagnostic.Logger; import com.intellij.openapi.ui.ComboBox; -import com.intellij.ui.JBColor; +import com.intellij.ui.EnumComboBoxModel; import com.intellij.ui.TitledSeparator; -import com.intellij.ui.components.*; +import com.intellij.ui.components.JBCheckBox; +import com.intellij.ui.components.JBPasswordField; +import com.intellij.ui.components.JBRadioButton; +import com.intellij.ui.components.JBTextField; import com.intellij.util.ui.FormBuilder; import com.intellij.util.ui.JBUI; import com.intellij.util.ui.UI; -import com.intellij.util.ui.UIUtil; import ee.carlrobert.codegpt.CodeGPTBundle; +import ee.carlrobert.codegpt.completions.you.YouUserManager; +import ee.carlrobert.codegpt.completions.you.auth.AuthenticationNotifier; import ee.carlrobert.codegpt.credentials.AzureCredentialsManager; import ee.carlrobert.codegpt.credentials.OpenAICredentialsManager; import ee.carlrobert.codegpt.settings.state.AzureSettingsState; import ee.carlrobert.codegpt.settings.state.OpenAISettingsState; -import ee.carlrobert.codegpt.settings.state.SettingsState; -import ee.carlrobert.codegpt.completions.you.YouUserManager; -import ee.carlrobert.codegpt.completions.you.auth.AuthenticationNotifier; import ee.carlrobert.codegpt.settings.state.YouSettingsState; -import ee.carlrobert.codegpt.telemetry.ui.utils.JBLabelUtils; import ee.carlrobert.codegpt.util.SwingUtils; import ee.carlrobert.llm.client.openai.completion.chat.OpenAIChatCompletionModel; -import ee.carlrobert.llm.completion.CompletionModel; - -import java.awt.*; import java.util.List; import java.util.Map; -import javax.swing.Box; import javax.swing.ButtonGroup; import javax.swing.JComponent; import javax.swing.JPanel; public class ServiceSelectionForm { + private static final Logger LOG = Logger.getInstance(ServiceSelectionForm.class); + private final Disposable parentDisposable; - private static final OpenAIChatCompletionModel[] DEFAULT_OPENAI_MODELS = new OpenAIChatCompletionModel[]{ - OpenAIChatCompletionModel.GPT_3_5, - OpenAIChatCompletionModel.GPT_3_5_16k, - OpenAIChatCompletionModel.GPT_4, - OpenAIChatCompletionModel.GPT_4_32k - }; - - private final JBRadioButton useOpenAIServiceRadioButton; - private final JBRadioButton useAzureServiceRadioButton; - private final JBPasswordField openAIApiKeyField; private final JBTextField openAIBaseHostField; private final JBTextField openAIPathField; private final JBTextField openAIOrganizationField; private final JPanel openAIServiceSectionPanel; - private final ComboBox openAICompletionModelComboBox; + private final ComboBox openAICompletionModelComboBox; private final JBRadioButton useAzureApiKeyAuthenticationRadioButton; private final JBPasswordField azureApiKeyField; @@ -66,13 +54,14 @@ public class ServiceSelectionForm { private final JBTextField azureDeploymentIdField; private final JBTextField azureApiVersionField; private final JPanel azureServiceSectionPanel; - private final ComboBox azureCompletionModelComboBox; + private final ComboBox azureCompletionModelComboBox; - private final JBRadioButton useYouServiceRadioButton; private final JPanel youServiceSectionPanel; private final JBCheckBox displayWebSearchResultsCheckBox; - public ServiceSelectionForm(Disposable parentDisposable, SettingsState settings) { + private final LlamaServiceSelectionForm llamaServiceSectionPanel; + + public ServiceSelectionForm(Disposable parentDisposable) { this.parentDisposable = parentDisposable; var openAISettings = OpenAISettingsState.getInstance(); var azureSettings = AzureSettingsState.getInstance(); @@ -83,68 +72,58 @@ public class ServiceSelectionForm { azureApiKeyField = new JBPasswordField(); azureApiKeyField.setColumns(30); azureApiKeyField.setText(AzureCredentialsManager.getInstance().getAzureOpenAIApiKey()); - azureApiKeyFieldPanel = UI.PanelFactory.panel(azureApiKeyField) - .withLabel("API key:") + .withLabel(CodeGPTBundle.get("settingsConfigurable.shared.apiKey.label")) .resizeX(false) .createPanel(); - azureActiveDirectoryTokenField = new JBPasswordField(); azureActiveDirectoryTokenField.setColumns(30); azureActiveDirectoryTokenField.setText( AzureCredentialsManager.getInstance().getAzureActiveDirectoryToken()); - azureActiveDirectoryTokenFieldPanel = UI.PanelFactory.panel(azureActiveDirectoryTokenField) - .withLabel("Bearer token:") + .withLabel(CodeGPTBundle.get("settingsConfigurable.service.azure.bearerToken.label")) .resizeX(false) .createPanel(); - useAzureApiKeyAuthenticationRadioButton = new JBRadioButton( - "Use API Key authentication", + CodeGPTBundle.get("settingsConfigurable.service.azure.useApiKeyAuth.label"), azureSettings.isUseAzureApiKeyAuthentication()); useAzureActiveDirectoryAuthenticationRadioButton = new JBRadioButton( - "Use Active Directory authentication", + CodeGPTBundle.get("settingsConfigurable.service.azure.useActiveDirectoryAuth.label"), azureSettings.isUseAzureActiveDirectoryAuthentication()); - useOpenAIServiceRadioButton = new JBRadioButton( - CodeGPTBundle.get("settingsConfigurable.section.service.useOpenAIServiceRadioButtonLabel"), - settings.isUseOpenAIService()); - useAzureServiceRadioButton = new JBRadioButton( - CodeGPTBundle.get("settingsConfigurable.section.service.useAzureServiceRadioButtonLabel"), - settings.isUseAzureService()); - useYouServiceRadioButton = new JBRadioButton( - CodeGPTBundle.get("settingsConfigurable.section.service.useYouServiceRadioButtonLabel"), - settings.isUseYouService()); - openAIBaseHostField = new JBTextField(openAISettings.getBaseHost(), 30); openAIPathField = new JBTextField(openAISettings.getPath(), 30); openAIOrganizationField = new JBTextField(openAISettings.getOrganization(), 30); - openAICompletionModelComboBox = new ModelComboBox( - DEFAULT_OPENAI_MODELS, - OpenAIChatCompletionModel.findByCode(openAISettings.getModel())); + + var selectedOpenAIModel = OpenAIChatCompletionModel.findByCode(openAISettings.getModel()); + + openAICompletionModelComboBox = new ComboBox<>( + new EnumComboBoxModel<>(OpenAIChatCompletionModel.class)); + openAICompletionModelComboBox.setSelectedItem(selectedOpenAIModel); azureBaseHostField = new JBTextField(azureSettings.getBaseHost(), 35); azurePathField = new JBTextField(azureSettings.getPath(), 35); azureResourceNameField = new JBTextField(azureSettings.getResourceName(), 35); azureDeploymentIdField = new JBTextField(azureSettings.getDeploymentId(), 35); azureApiVersionField = new JBTextField(azureSettings.getApiVersion(), 35); - azureCompletionModelComboBox = new ModelComboBox( - DEFAULT_OPENAI_MODELS, - OpenAIChatCompletionModel.findByCode(azureSettings.getModel())); + azureCompletionModelComboBox = new ComboBox<>( + new EnumComboBoxModel<>(OpenAIChatCompletionModel.class)); + azureCompletionModelComboBox.setSelectedItem(selectedOpenAIModel); azureCompletionModelComboBox.getEditor() .getEditorComponent() .setMaximumSize(azureBaseHostField.getPreferredSize()); displayWebSearchResultsCheckBox = new JBCheckBox( - "Display web search results", + CodeGPTBundle.get("settingsConfigurable.service.you.displayResults.label"), YouSettingsState.getInstance().isDisplayWebSearchResults()); displayWebSearchResultsCheckBox.setEnabled(YouUserManager.getInstance().isAuthenticated()); openAIServiceSectionPanel = createOpenAIServiceSectionPanel(); azureServiceSectionPanel = createAzureServiceSectionPanel(); youServiceSectionPanel = createYouServiceSectionPanel(); + llamaServiceSectionPanel = new LlamaServiceSelectionForm(); - registerPanelsVisibility(settings, azureSettings); + registerPanelsVisibility(azureSettings); registerRadioButtons(); ApplicationManager.getApplication() @@ -154,58 +133,42 @@ public class ServiceSelectionForm { (AuthenticationNotifier) () -> displayWebSearchResultsCheckBox.setEnabled(true)); } - public JPanel getForm() { - var panel = new JPanel(new FlowLayout(FlowLayout.LEADING, 0, 0)); - panel.add(useOpenAIServiceRadioButton); - if (OpenAISettingsState.getInstance().isOpenAIQuotaExceeded()) { - panel.add(Box.createHorizontalStrut(4)); - panel.add(new JBLabel("quota exceeded")); - } - // flow layout's horizontal gap adds annoying horizontal padding on each sides - panel.add(Box.createHorizontalStrut(16)); - panel.add(useAzureServiceRadioButton); - panel.add(Box.createHorizontalStrut(16)); - panel.add(useYouServiceRadioButton); - - return FormBuilder.createFormBuilder() - .addComponent(withEmptyLeftBorder(panel)) - .addComponent(openAIServiceSectionPanel) - .addComponent(azureServiceSectionPanel) - .addComponent(youServiceSectionPanel) - .getPanel(); - } - private JPanel createOpenAIServiceSectionPanel() { var requestConfigurationPanel = UI.PanelFactory.grid() .add(UI.PanelFactory.panel(openAIOrganizationField) .withLabel(CodeGPTBundle.get( - "settingsConfigurable.section.service.openai.organizationField.label")) + "settingsConfigurable.service.openai.organization.label")) .resizeX(false) .withComment(CodeGPTBundle.get( - "settingsConfigurable.section.service.openai.organizationField.comment"))) + "settingsConfigurable.section.openai.organization.comment"))) .add(UI.PanelFactory.panel(openAIBaseHostField) - .withLabel("Base host:") + .withLabel(CodeGPTBundle.get( + "settingsConfigurable.shared.baseHost.label")) .resizeX(false)) .add(UI.PanelFactory.panel(openAIPathField) - .withLabel("Path:") + .withLabel(CodeGPTBundle.get( + "settingsConfigurable.shared.path.label")) .resizeX(false)) .add(UI.PanelFactory.panel(openAICompletionModelComboBox) - .withLabel("Model:") + .withLabel(CodeGPTBundle.get( + "settingsConfigurable.shared.model.label")) .resizeX(false)) .createPanel(); var apiKeyFieldPanel = UI.PanelFactory.panel(openAIApiKeyField) - .withLabel(CodeGPTBundle.get("settingsConfigurable.section.integration.apiKeyField.label")) + .withLabel(CodeGPTBundle.get("settingsConfigurable.shared.apiKey.label")) .resizeX(false) .withComment( - CodeGPTBundle.get("settingsConfigurable.section.integration.apiKeyField.comment")) + CodeGPTBundle.get("settingsConfigurable.service.openai.apiKey.comment")) .withCommentHyperlinkListener(SwingUtils::handleHyperlinkClicked) .createPanel(); return FormBuilder.createFormBuilder() - .addComponent(new TitledSeparator("Authentication")) + .addComponent(new TitledSeparator( + CodeGPTBundle.get("settingsConfigurable.shared.authentication.title"))) .addComponent(withEmptyLeftBorder(apiKeyFieldPanel)) - .addComponent(new TitledSeparator("Request Configuration")) + .addComponent(new TitledSeparator( + CodeGPTBundle.get("settingsConfigurable.shared.requestConfiguration.title"))) .addComponent(withEmptyLeftBorder(requestConfigurationPanel)) .addComponentFillVertically(new JPanel(), 0) .getPanel(); @@ -228,46 +191,51 @@ public class ServiceSelectionForm { var configPanel = withEmptyLeftBorder(UI.PanelFactory.grid() .add(UI.PanelFactory.panel(azureResourceNameField) .withLabel(CodeGPTBundle.get( - "settingsConfigurable.section.service.azure.resourceNameField.label")) + "settingsConfigurable.service.azure.resourceName.label")) .resizeX(false) .withComment(CodeGPTBundle.get( - "settingsConfigurable.section.service.azure.resourceNameField.comment"))) + "settingsConfigurable.service.azure.resourceName.comment"))) .add(UI.PanelFactory.panel(azureDeploymentIdField) .withLabel(CodeGPTBundle.get( - "settingsConfigurable.section.service.azure.deploymentIdField.label")) + "settingsConfigurable.service.azure.deploymentId.label")) .resizeX(false) .withComment(CodeGPTBundle.get( - "settingsConfigurable.section.service.azure.deploymentIdField.comment"))) + "settingsConfigurable.service.azure.deploymentId.comment"))) .add(UI.PanelFactory.panel(azureApiVersionField) .withLabel(CodeGPTBundle.get( - "settingsConfigurable.section.service.azure.apiVersionField.label")) + "settingsConfigurable.service.azure.apiVersion.label")) .resizeX(false) .withComment(CodeGPTBundle.get( - "settingsConfigurable.section.service.azure.apiVersionField.comment"))) + "settingsConfigurable.service.azure.apiVersion.comment"))) .add(UI.PanelFactory.panel(azureBaseHostField) - .withLabel("Base host:") + .withLabel(CodeGPTBundle.get("settingsConfigurable.shared.baseHost.label")) .resizeX(false)) .add(UI.PanelFactory.panel(azurePathField) - .withLabel("Path:") + .withLabel(CodeGPTBundle.get("settingsConfigurable.shared.path.label")) .resizeX(false)) .add(UI.PanelFactory.panel(azureCompletionModelComboBox) - .withLabel("Model:") + .withLabel(CodeGPTBundle.get("settingsConfigurable.shared.model.label")) .resizeX(false)) .createPanel()); return FormBuilder.createFormBuilder() - .addComponent(new TitledSeparator("Authentication")) + .addComponent(new TitledSeparator( + CodeGPTBundle.get("settingsConfigurable.shared.authentication.title"))) .addComponent(authPanel) - .addComponent(new TitledSeparator("Request Configuration")) + .addComponent(new TitledSeparator( + CodeGPTBundle.get("settingsConfigurable.shared.requestConfiguration.title"))) .addComponent(configPanel) + .addComponentFillVertically(new JPanel(), 0) .getPanel(); } private JPanel createYouServiceSectionPanel() { return FormBuilder.createFormBuilder() - .addComponent(new YouServiceSelectionPanel(parentDisposable)) - .addComponent(new TitledSeparator("Chat Preferences")) + .addComponent(new YouServiceSelectionForm(parentDisposable)) + .addComponent(new TitledSeparator( + CodeGPTBundle.get("settingsConfigurable.service.you.chatPreferences.title"))) .addComponent(withEmptyLeftBorder(displayWebSearchResultsCheckBox)) + .addComponentFillVertically(new JPanel(), 0) .getPanel(); } @@ -276,32 +244,22 @@ public class ServiceSelectionForm { return component; } - private void registerPanelsVisibility(SettingsState settings, AzureSettingsState azureSettings) { - openAIServiceSectionPanel.setVisible(settings.isUseOpenAIService()); - azureServiceSectionPanel.setVisible(settings.isUseAzureService()); + private void registerPanelsVisibility(AzureSettingsState azureSettings) { azureApiKeyFieldPanel.setVisible(azureSettings.isUseAzureApiKeyAuthentication()); azureActiveDirectoryTokenFieldPanel.setVisible( azureSettings.isUseAzureActiveDirectoryAuthentication()); - youServiceSectionPanel.setVisible(settings.isUseYouService()); } private void registerRadioButtons() { - registerRadioButtons( - List.of( - Map.entry(useOpenAIServiceRadioButton, openAIServiceSectionPanel), - Map.entry(useAzureServiceRadioButton, azureServiceSectionPanel), - Map.entry(useYouServiceRadioButton, youServiceSectionPanel))); - registerRadioButtons( - List.of( - Map.entry(useAzureApiKeyAuthenticationRadioButton, azureApiKeyFieldPanel), - Map.entry(useAzureActiveDirectoryAuthenticationRadioButton, - azureActiveDirectoryTokenFieldPanel))); + registerRadioButtons(List.of( + Map.entry(useAzureApiKeyAuthenticationRadioButton, azureApiKeyFieldPanel), + Map.entry(useAzureActiveDirectoryAuthenticationRadioButton, + azureActiveDirectoryTokenFieldPanel))); } private void registerRadioButtons(List> entries) { var buttonGroup = new ButtonGroup(); entries.forEach(entry -> buttonGroup.add(entry.getKey())); - entries.forEach(entry -> entry.getKey().addActionListener((e) -> { for (Map.Entry innerEntry : entries) { innerEntry.getValue().setVisible(innerEntry.equals(entry)); @@ -309,42 +267,6 @@ public class ServiceSelectionForm { })); } - public OpenAIChatCompletionModel getSelectedCompletionModel() { - return (OpenAIChatCompletionModel) (isOpenAIServiceSelected() ? - openAICompletionModelComboBox.getSelectedItem() : - azureCompletionModelComboBox.getSelectedItem()); - } - - public void setSelectedChatCompletionModel(OpenAIChatCompletionModel chatCompletionModel) { - if (isOpenAIServiceSelected()) { - openAICompletionModelComboBox.setSelectedItem(chatCompletionModel); - } - } - - public void setOpenAIServiceSelected(boolean selected) { - useOpenAIServiceRadioButton.setSelected(selected); - } - - public boolean isOpenAIServiceSelected() { - return useOpenAIServiceRadioButton.isSelected(); - } - - public void setAzureServiceSelected(boolean selected) { - useAzureServiceRadioButton.setSelected(selected); - } - - public boolean isAzureServiceSelected() { - return useAzureServiceRadioButton.isSelected(); - } - - public boolean isYouServiceSelected() { - return useYouServiceRadioButton.isSelected(); - } - - public void setYouServiceSelected(boolean selected) { - useYouServiceRadioButton.setSelected(selected); - } - public void setOpenAIApiKey(String apiKey) { openAIApiKeyField.setText(apiKey); } @@ -461,6 +383,10 @@ public class ServiceSelectionForm { return displayWebSearchResultsCheckBox.isSelected(); } + public LlamaModelPreferencesForm getLlamaModelPreferencesForm() { + return llamaServiceSectionPanel.getLlamaModelPreferencesForm(); + } + public void setOpenAIPath(String path) { openAIPathField.setText(path); } @@ -476,4 +402,36 @@ public class ServiceSelectionForm { public String getAzurePath() { return azurePathField.getText(); } + + public void setLlamaServerPort(int serverPort) { + llamaServiceSectionPanel.setServerPort(serverPort); + } + + public int getLlamaServerPort() { + return llamaServiceSectionPanel.getServerPort(); + } + + public JPanel getOpenAIServiceSectionPanel() { + return openAIServiceSectionPanel; + } + + public JPanel getAzureServiceSectionPanel() { + return azureServiceSectionPanel; + } + + public JPanel getYouServiceSectionPanel() { + return youServiceSectionPanel; + } + + public JPanel getLlamaServiceSectionPanel() { + return llamaServiceSectionPanel; + } + + public int getContextSize() { + return llamaServiceSectionPanel.getContextSize(); + } + + public void setContextSize(int contextSize) { + llamaServiceSectionPanel.setContextSize(contextSize); + } } diff --git a/src/main/java/ee/carlrobert/codegpt/settings/service/ServiceType.java b/src/main/java/ee/carlrobert/codegpt/settings/service/ServiceType.java new file mode 100644 index 00000000..764ed33e --- /dev/null +++ b/src/main/java/ee/carlrobert/codegpt/settings/service/ServiceType.java @@ -0,0 +1,31 @@ +package ee.carlrobert.codegpt.settings.service; + +import ee.carlrobert.codegpt.CodeGPTBundle; + +public enum ServiceType { + OPENAI("OPENAI", CodeGPTBundle.get("service.openai.title")), + AZURE("AZURE", CodeGPTBundle.get("service.azure.title")), + YOU("YOU", CodeGPTBundle.get("service.you.title")), + LLAMA_CPP("LLAMA_CPP", CodeGPTBundle.get("service.llama.title")); + + private final String code; + private final String label; + + ServiceType(String code, String label) { + this.code = code; + this.label = label; + } + + public String getCode() { + return code; + } + + public String getLabel() { + return label; + } + + @Override + public String toString() { + return label; + } +} diff --git a/src/main/java/ee/carlrobert/codegpt/settings/YouServiceSelectionPanel.java b/src/main/java/ee/carlrobert/codegpt/settings/service/YouServiceSelectionForm.java similarity index 78% rename from src/main/java/ee/carlrobert/codegpt/settings/YouServiceSelectionPanel.java rename to src/main/java/ee/carlrobert/codegpt/settings/service/YouServiceSelectionForm.java index 89bdd04c..72c3e916 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/YouServiceSelectionPanel.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/service/YouServiceSelectionForm.java @@ -1,4 +1,4 @@ -package ee.carlrobert.codegpt.settings; +package ee.carlrobert.codegpt.settings.service; import com.intellij.openapi.Disposable; import com.intellij.openapi.ui.ComponentValidator; @@ -9,7 +9,6 @@ import com.intellij.ui.TitledSeparator; import com.intellij.ui.components.JBLabel; import com.intellij.ui.components.JBPasswordField; import com.intellij.ui.components.JBTextField; -import com.intellij.ui.components.OnOffButton; import com.intellij.util.ui.AsyncProcessIcon; import com.intellij.util.ui.FormBuilder; import com.intellij.util.ui.JBFont; @@ -35,7 +34,7 @@ import javax.swing.JTextPane; import javax.swing.SwingUtilities; import org.jetbrains.annotations.Nullable; -public class YouServiceSelectionPanel extends JPanel { +public class YouServiceSelectionForm extends JPanel { private final JBTextField emailField; private final JBPasswordField passwordField; @@ -43,7 +42,7 @@ public class YouServiceSelectionPanel extends JPanel { private final JTextPane signUpTextPane; private final AsyncProcessIcon loadingSpinner; - public YouServiceSelectionPanel(Disposable parentDisposable) { + public YouServiceSelectionForm(Disposable parentDisposable) { super(new BorderLayout()); var settings = SettingsState.getInstance(); emailField = new JBTextField(settings.getEmail(), 25); @@ -52,8 +51,7 @@ public class YouServiceSelectionPanel extends JPanel { if (!settings.getEmail().isEmpty()) { passwordField.setText(YouCredentialsManager.getInstance().getAccountPassword()); } - signInButton = new JButton( - CodeGPTBundle.get("settingsConfigurable.section.userAuthentication.signIn.label")); + signInButton = new JButton(CodeGPTBundle.get("settingsConfigurable.service.you.signIn.label")); signUpTextPane = createSignUpTextPane(); loadingSpinner = new AsyncProcessIcon("sign_in_spinner"); loadingSpinner.setBorder(JBUI.Borders.emptyLeft(8)); @@ -106,7 +104,8 @@ public class YouServiceSelectionPanel extends JPanel { if (component instanceof JBTextField) { value = ((JBTextField) component).getText(); if (!isValidEmail(value)) { - return new ValidationInfo("The email you entered is invalid.", component) + return new ValidationInfo( + CodeGPTBundle.get("validation.error.invalidEmail"), component) .withOKEnabled(); } } else { @@ -114,7 +113,9 @@ public class YouServiceSelectionPanel extends JPanel { } if (StringUtil.isEmpty(value)) { - return new ValidationInfo("This field is required.", component).withOKEnabled(); + return new ValidationInfo( + CodeGPTBundle.get("validation.error.fieldRequired"), component) + .withOKEnabled(); } return null; @@ -134,18 +135,14 @@ public class YouServiceSelectionPanel extends JPanel { private JTextPane createSignUpTextPane() { var textPane = createTextPane( - "Don't have an account?
Sign up with 'CodeGPT' coupon for free GPT-4
"); + "Don't have an account? Sign up"); textPane.setBorder(JBUI.Borders.emptyLeft(4)); return textPane; } private JTextPane createTextPane(String htmlContent) { - var textPane = new JTextPane(); - textPane.setContentType("text/html"); - textPane.putClientProperty(JTextPane.HONOR_DISPLAY_PROPERTIES, true); + var textPane = SwingUtils.createTextPane(SwingUtils::handleHyperlinkClicked); textPane.setText(htmlContent); - textPane.addHyperlinkListener(SwingUtils::handleHyperlinkClicked); - textPane.setEditable(false); return textPane; } @@ -162,9 +159,23 @@ public class YouServiceSelectionPanel extends JPanel { JBTextField emailAddressField, JBPasswordField passwordField, @Nullable YouAuthenticationError error) { + var couponLabel = new JBLabel( + "" + + "" + + "

Free GPT-4

" + + "

Your coupon code

" + + "

CODEGPT

" + + "" + + "") + .withBorder(JBUI.Borders.emptyLeft(45)) // TODO + .setCopyable(true); + var contentPanelBuilder = FormBuilder.createFormBuilder() - .addLabeledComponent("Email address:", emailAddressField) - .addLabeledComponent("Password:", passwordField) + .addComponentToRightColumn(JBUI.Panels.simplePanel().addToLeft(couponLabel)) + .addLabeledComponent(CodeGPTBundle.get("settingsConfigurable.service.you.email.label"), + emailAddressField) + .addLabeledComponent(CodeGPTBundle.get("settingsConfigurable.service.you.password.label"), + passwordField) .addVerticalGap(4) .addComponentToRightColumn(createFooterPanel()) .addVerticalGap(4); @@ -176,22 +187,24 @@ public class YouServiceSelectionPanel extends JPanel { contentPanelBuilder.addComponentToRightColumn(invalidCredentialsLabel); } + var contentPanel = contentPanelBuilder.getPanel(); + contentPanel.setBorder(JBUI.Borders.emptyLeft(16)); + return FormBuilder.createFormBuilder() .addComponent(new TitledSeparator( - CodeGPTBundle.get("settingsConfigurable.section.userAuthentication.title"))) - .addComponent(JBUI.Panels - .simplePanel(contentPanelBuilder.getPanel()) - .withBorder(JBUI.Borders.emptyLeft(16))) + CodeGPTBundle.get("settingsConfigurable.service.you.authentication.title"))) + .addComponent(contentPanel) .getPanel(); } private JPanel createUserInformationPanel(YouUser user) { var userManager = YouUserManager.getInstance(); var contentPanelBuilder = FormBuilder.createFormBuilder() - .addLabeledComponent("Email address:", + .addLabeledComponent(CodeGPTBundle.get("settingsConfigurable.service.you.email.label"), new JBLabel(user.getEmails().get(0).getEmail()).withFont(JBFont.label().asBold())); - var signOutButton = new JButton("Sign Out"); + var signOutButton = new JButton( + CodeGPTBundle.get("settingsConfigurable.service.you.signOut.label")); signOutButton.addActionListener(e -> { userManager.clearSession(); refreshView(createUserAuthenticationPanel(emailField, passwordField, null)); @@ -199,7 +212,7 @@ public class YouServiceSelectionPanel extends JPanel { return FormBuilder.createFormBuilder() .addComponent(new TitledSeparator( - CodeGPTBundle.get("settingsConfigurable.section.userInformation.title"))) + CodeGPTBundle.get("settingsConfigurable.service.you.userInformation.title"))) .addVerticalGap(8) .addComponent(JBUI.Panels .simplePanel(contentPanelBuilder.addVerticalGap(4) diff --git a/src/main/java/ee/carlrobert/codegpt/settings/state/AzureSettingsState.java b/src/main/java/ee/carlrobert/codegpt/settings/state/AzureSettingsState.java index 44b16edb..efbbfaf2 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/state/AzureSettingsState.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/state/AzureSettingsState.java @@ -6,7 +6,7 @@ import com.intellij.openapi.components.State; import com.intellij.openapi.components.Storage; import com.intellij.util.xmlb.XmlSerializerUtil; import ee.carlrobert.codegpt.credentials.AzureCredentialsManager; -import ee.carlrobert.codegpt.settings.ServiceSelectionForm; +import ee.carlrobert.codegpt.settings.service.ServiceSelectionForm; import ee.carlrobert.llm.client.openai.completion.chat.OpenAIChatCompletionModel; import org.jetbrains.annotations.NotNull; diff --git a/src/main/java/ee/carlrobert/codegpt/settings/state/LlamaSettingsState.java b/src/main/java/ee/carlrobert/codegpt/settings/state/LlamaSettingsState.java new file mode 100644 index 00000000..5b505329 --- /dev/null +++ b/src/main/java/ee/carlrobert/codegpt/settings/state/LlamaSettingsState.java @@ -0,0 +1,96 @@ +package ee.carlrobert.codegpt.settings.state; + +import com.intellij.openapi.application.ApplicationManager; +import com.intellij.openapi.components.PersistentStateComponent; +import com.intellij.openapi.components.State; +import com.intellij.openapi.components.Storage; +import com.intellij.util.xmlb.XmlSerializerUtil; +import ee.carlrobert.codegpt.completions.HuggingFaceModel; +import ee.carlrobert.codegpt.completions.llama.PromptTemplate; +import java.io.IOException; +import java.net.ServerSocket; +import org.jetbrains.annotations.NotNull; + +@State(name = "CodeGPT_LlamaSettings", storages = @Storage("CodeGPT_CodeGPT_LlamaSettings.xml")) +public class LlamaSettingsState implements PersistentStateComponent { + + private boolean useCustomModel; + private String customLlamaModelPath = ""; + private HuggingFaceModel huggingFaceModel = HuggingFaceModel.CODE_LLAMA_7B_Q4; + private PromptTemplate promptTemplate = PromptTemplate.LLAMA; + private int serverPort = getRandomAvailablePortOrDefault(); + private int contextSize = 2048; + + public LlamaSettingsState() { + } + + public static LlamaSettingsState getInstance() { + return ApplicationManager.getApplication().getService(LlamaSettingsState.class); + } + + @Override + public LlamaSettingsState getState() { + return this; + } + + @Override + public void loadState(@NotNull LlamaSettingsState state) { + XmlSerializerUtil.copyBean(state, this); + } + + public boolean isUseCustomModel() { + return useCustomModel; + } + + public void setUseCustomModel(boolean useCustomModel) { + this.useCustomModel = useCustomModel; + } + + public String getCustomLlamaModelPath() { + return customLlamaModelPath; + } + + public void setCustomLlamaModelPath(String customLlamaModelPath) { + this.customLlamaModelPath = customLlamaModelPath; + } + + public HuggingFaceModel getHuggingFaceModel() { + return huggingFaceModel; + } + + public void setHuggingFaceModel(HuggingFaceModel huggingFaceModel) { + this.huggingFaceModel = huggingFaceModel; + } + + public PromptTemplate getPromptTemplate() { + return promptTemplate; + } + + public void setPromptTemplate(PromptTemplate promptTemplate) { + this.promptTemplate = promptTemplate; + } + + public int getServerPort() { + return serverPort; + } + + public void setServerPort(int serverPort) { + this.serverPort = serverPort; + } + + public int getContextSize() { + return contextSize; + } + + public void setContextSize(int contextSize) { + this.contextSize = contextSize; + } + + private static Integer getRandomAvailablePortOrDefault() { + try (ServerSocket socket = new ServerSocket(0)) { + return socket.getLocalPort(); + } catch (IOException e) { + return 8080; + } + } +} diff --git a/src/main/java/ee/carlrobert/codegpt/settings/state/OpenAISettingsState.java b/src/main/java/ee/carlrobert/codegpt/settings/state/OpenAISettingsState.java index d841cb2f..aeee5461 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/state/OpenAISettingsState.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/state/OpenAISettingsState.java @@ -6,7 +6,7 @@ import com.intellij.openapi.components.State; import com.intellij.openapi.components.Storage; import com.intellij.util.xmlb.XmlSerializerUtil; import ee.carlrobert.codegpt.credentials.OpenAICredentialsManager; -import ee.carlrobert.codegpt.settings.ServiceSelectionForm; +import ee.carlrobert.codegpt.settings.service.ServiceSelectionForm; import ee.carlrobert.llm.client.openai.completion.chat.OpenAIChatCompletionModel; import org.jetbrains.annotations.NotNull; diff --git a/src/main/java/ee/carlrobert/codegpt/settings/state/SettingsState.java b/src/main/java/ee/carlrobert/codegpt/settings/state/SettingsState.java index b1b43fd1..feb1183a 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/state/SettingsState.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/state/SettingsState.java @@ -5,6 +5,7 @@ import com.intellij.openapi.components.PersistentStateComponent; import com.intellij.openapi.components.State; import com.intellij.openapi.components.Storage; import com.intellij.util.xmlb.XmlSerializerUtil; +import ee.carlrobert.codegpt.completions.HuggingFaceModel; import ee.carlrobert.codegpt.conversations.Conversation; import org.jetbrains.annotations.NotNull; @@ -17,6 +18,7 @@ public class SettingsState implements PersistentStateComponent { private boolean useOpenAIService = true; private boolean useAzureService; private boolean useYouService; + private boolean useLlamaService; public SettingsState() { } @@ -43,10 +45,15 @@ public class SettingsState implements PersistentStateComponent { if ("azure.chat.completion".equals(clientCode)) { AzureSettingsState.getInstance().setModel(conversation.getModel()); } + if ("llama.chat.completion".equals(clientCode)) { + LlamaSettingsState.getInstance().setHuggingFaceModel( + HuggingFaceModel.valueOf(conversation.getModel())); + } setUseOpenAIService("chat.completion".equals(clientCode)); setUseAzureService("azure.chat.completion".equals(clientCode)); setUseYouService("you.chat.completion".equals(clientCode)); + setUseLlamaService("llama.chat.completion".equals(clientCode)); } public String getEmail() { @@ -103,4 +110,12 @@ public class SettingsState implements PersistentStateComponent { public void setUseYouService(boolean useYouService) { this.useYouService = useYouService; } + + public boolean isUseLlamaService() { + return useLlamaService; + } + + public void setUseLlamaService(boolean useLlamaService) { + this.useLlamaService = useLlamaService; + } } diff --git a/src/main/java/ee/carlrobert/codegpt/settings/state/YouSettingsState.java b/src/main/java/ee/carlrobert/codegpt/settings/state/YouSettingsState.java index d91201b4..4140672c 100644 --- a/src/main/java/ee/carlrobert/codegpt/settings/state/YouSettingsState.java +++ b/src/main/java/ee/carlrobert/codegpt/settings/state/YouSettingsState.java @@ -12,6 +12,7 @@ public class YouSettingsState implements PersistentStateComponent\n" + + "\n" + + "

Use CodeGPT coupon for free month of GPT-4.

\n" + + "

\n" + + " Sign up here\n" + + "

\n" + + "\n" + + ""); + return textPane; + } + @Override public void startNewConversation(Message message) { conversation = conversationService.startConversation(); @@ -165,7 +195,9 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan requestHandler.withContextualSearch(useContextualSearch); requestHandler.addMessageListener(partialMessage -> { try { - responseContainer.update(partialMessage); + LOG.debug(partialMessage); + ApplicationManager.getApplication() + .invokeLater(() -> responseContainer.update(partialMessage)); } catch (Exception e) { responseContainer.displayDefaultError(); throw new RuntimeException("Error while updating the content", e); @@ -354,23 +386,33 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan JBUI.Borders.empty(8))); wrapper.setBackground(getPanelBackgroundColor()); wrapper.add(userPromptTextArea, BorderLayout.SOUTH); + if (model != null) { var header = new JPanel(new BorderLayout()); header.setBackground(getPanelBackgroundColor()); header.setBorder(JBUI.Borders.emptyBottom(8)); if ("YouCode".equals(model)) { + var messageBusConnection = ApplicationManager.getApplication().getMessageBus().connect(); subscribeToYouModelChangeTopic(); - subscribeToYouAuthTopic(); + subscribeToYouSubscriptionTopic(messageBusConnection); + subscribeToSignedOutTopic(messageBusConnection); header.add(gpt4CheckBox, BorderLayout.LINE_START); } header.add(modelIconWrapper, BorderLayout.LINE_END); wrapper.add(header); } + rootPanel.add(wrapper, gbc); userPromptTextArea.requestFocusInWindow(); userPromptTextArea.requestFocus(); } + private void subscribeToSignedOutTopic(MessageBusConnection messageBusConnection) { + messageBusConnection.subscribe( + SignedOutNotifier.SIGNED_OUT_TOPIC, + (SignedOutNotifier) () -> gpt4CheckBox.setEnabled(false)); + } + private void subscribeToYouModelChangeTopic() { project.getMessageBus() .connect() @@ -379,24 +421,26 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan (YouModelChangeNotifier) gpt4CheckBox::setSelected); } - private void subscribeToYouAuthTopic() { - ApplicationManager.getApplication() - .getMessageBus() - .connect() - .subscribe(AuthenticationNotifier.AUTHENTICATION_TOPIC, - (AuthenticationNotifier) () -> gpt4CheckBox.setEnabled(true)); + private void subscribeToYouSubscriptionTopic(MessageBusConnection messageBusConnection) { + messageBusConnection.subscribe( + YouSubscriptionNotifier.SUBSCRIPTION_TOPIC, + (YouSubscriptionNotifier) () -> { + displayLandingView(); + gpt4CheckBox.setEnabled(true); + }); } private JBCheckBox createGPT4ModelCheckBox() { - var gpt4CheckBox = new JBCheckBox("Use GPT-4 model"); + var gpt4CheckBox = new JBCheckBox(CodeGPTBundle.get("toolwindow.chat.youProCheckBox.text")); gpt4CheckBox.setOpaque(false); - gpt4CheckBox.setEnabled(YouUserManager.getInstance().isAuthenticated()); + gpt4CheckBox.setEnabled(YouUserManager.getInstance().isSubscribed()); gpt4CheckBox.setSelected(YouSettingsState.getInstance().isUseGPT4Model()); gpt4CheckBox.setToolTipText(getTooltipText(gpt4CheckBox.isSelected())); gpt4CheckBox.addChangeListener(e -> { var selected = ((JBCheckBox) e.getSource()).isSelected(); var tooltipText = getTooltipText(selected); gpt4CheckBox.setToolTipText(tooltipText); + // TODO: Remove project.getMessageBus() .syncPublisher(YouModelChangeNotifier.YOU_MODEL_CHANGE_NOTIFIER_TOPIC) .modelChanged(selected); @@ -406,9 +450,12 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan } private String getTooltipText(boolean selected) { - return selected ? - "Turn off for faster responses" : - "Turn on for complex queries, enable by creating an account on you.com
and signing in from plugin settings.
Use CodeGPT coupon for free month of GPT-4."; + if (YouUserManager.getInstance().isSubscribed()) { + return selected ? + CodeGPTBundle.get("toolwindow.chat.youProCheckBox.disable") : + CodeGPTBundle.get("toolwindow.chat.youProCheckBox.enable"); + } + return CodeGPTBundle.get("toolwindow.chat.youProCheckBox.notAllowed"); } private String getClientCode() { @@ -422,6 +469,9 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan if (settings.isUseYouService()) { return "you.chat.completion"; } + if (settings.isUseLlamaService()) { + return "llama.chat.completion"; + } return null; } @@ -436,7 +486,25 @@ public abstract class BaseChatToolWindowTabPanel implements ChatToolWindowTabPan if (settings.isUseYouService()) { return "YouCode"; } + if (settings.isUseLlamaService()) { + var llamaSettings = LlamaSettingsState.getInstance(); + if (llamaSettings.isUseCustomModel()) { + var filePath = llamaSettings.getCustomLlamaModelPath(); + int lastSeparatorIndex = filePath.lastIndexOf('/'); + if (lastSeparatorIndex == -1) { + return filePath; + } + return filePath.substring(lastSeparatorIndex + 1); + } + var huggingFaceModel = llamaSettings.getHuggingFaceModel(); + var llamaModel = LlamaModel.findByHuggingFaceModel(huggingFaceModel); + return String.format( + "%s %dB (Q%d)", + llamaModel.getLabel(), + huggingFaceModel.getParameterSize(), + huggingFaceModel.getQuantization()); + } - return null; + return "Unknown"; } } diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/StreamParser.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/StreamParser.java index 10dacffe..4a159219 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/StreamParser.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/StreamParser.java @@ -11,6 +11,7 @@ public class StreamParser { private boolean isProcessingCode; public List parse(String message) { + message = message.replace("\r", ""); messageBuilder.append(message); Pattern pattern = Pattern.compile(CODE_BLOCK_STARTING_REGEX); diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/components/ChatMessageResponseBody.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/components/ChatMessageResponseBody.java index b38c3d30..f7efdf3c 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/components/ChatMessageResponseBody.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/components/ChatMessageResponseBody.java @@ -22,7 +22,6 @@ import com.vladsch.flexmark.util.data.MutableDataSet; import ee.carlrobert.codegpt.actions.ActionType; import ee.carlrobert.codegpt.completions.you.YouSerpResult; import ee.carlrobert.codegpt.settings.SettingsConfigurable; -import ee.carlrobert.codegpt.settings.state.SettingsState; import ee.carlrobert.codegpt.telemetry.TelemetryAction; import ee.carlrobert.codegpt.toolwindow.chat.ResponseNodeRenderer; import ee.carlrobert.codegpt.toolwindow.chat.StreamParser; @@ -231,13 +230,13 @@ public class ChatMessageResponseBody extends JPanel { add(currentlyProcessedElement); } - private void prepareProcessingCodeResponse(String code, String language) { + private void prepareProcessingCodeResponse(String code, String markdownLanguage) { hideCarets(); currentlyProcessedTextPane = null; currentlyProcessedEditor = new ResponseEditor( project, code, - language, + markdownLanguage, parentDisposable); currentlyProcessedElement = new ResponseWrapper(); @@ -249,15 +248,13 @@ public class ChatMessageResponseBody extends JPanel { var editor = currentlyProcessedEditor.getEditor(); var document = editor.getDocument(); var application = ApplicationManager.getApplication(); - Runnable updateDocumentRunnable = () -> { - application.runWriteAction(() -> - WriteCommandAction.runWriteCommandAction(project, () -> { - document.replaceString(0, document.getTextLength(), code); - editor.getCaretModel().moveToOffset(code.length()); - editor.getComponent().revalidate(); - editor.getComponent().repaint(); - })); - }; + Runnable updateDocumentRunnable = () -> application.runWriteAction(() -> + WriteCommandAction.runWriteCommandAction(project, () -> { + document.replaceString(0, document.getTextLength(), code); + editor.getCaretModel().moveToOffset(code.length()); + editor.getComponent().revalidate(); + editor.getComponent().repaint(); + })); if (application.isUnitTestMode()) { application.invokeAndWait(updateDocumentRunnable); @@ -267,8 +264,7 @@ public class ChatMessageResponseBody extends JPanel { } private JTextPane createTextPane() { - var textPane = new JTextPane(); - textPane.addHyperlinkListener(event -> { + var textPane = SwingUtils.createTextPane(event -> { if (FileUtil.exists(event.getDescription()) && ACTIVATED.equals(event.getEventType())) { VirtualFile file = LocalFileSystem.getInstance().findFileByPath(event.getDescription()); FileEditorManager.getInstance(project).openFile(Objects.requireNonNull(file), true); @@ -277,14 +273,10 @@ public class ChatMessageResponseBody extends JPanel { SwingUtils.handleHyperlinkClicked(event); }); - textPane.setContentType("text/html"); - textPane.putClientProperty(JTextPane.HONOR_DISPLAY_PROPERTIES, true); - textPane.setCaretPosition(textPane.getDocument().getLength()); - textPane.setBackground(getBackground()); - textPane.setFocusable(true); textPane.getCaret().setVisible(true); - textPane.setEditable(false); + textPane.setCaretPosition(textPane.getDocument().getLength()); textPane.setBorder(JBUI.Borders.empty()); + textPane.setBackground(getBackground()); return textPane; } diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/components/ResponsePanel.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/components/ResponsePanel.java index ef58c2da..fbfc4a19 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/components/ResponsePanel.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/components/ResponsePanel.java @@ -9,6 +9,7 @@ import com.intellij.ui.JBColor; import com.intellij.ui.components.JBLabel; import com.intellij.util.ui.JBFont; import com.intellij.util.ui.JBUI; +import ee.carlrobert.codegpt.CodeGPTBundle; import ee.carlrobert.codegpt.Icons; import java.awt.BorderLayout; import java.awt.FlowLayout; @@ -82,7 +83,10 @@ public class ResponsePanel extends JPanel { public void addReloadAction(Runnable onReload) { addIconActionButton(new IconActionButton( - new AnAction("Reload Response", "Reload response description", Actions.Refresh) { + new AnAction( + CodeGPTBundle.get("toolwindow.chat.response.action.reloadResponse.text"), + CodeGPTBundle.get("toolwindow.chat.response.action.reloadResponse.description"), + Actions.Refresh) { @Override public void actionPerformed(@NotNull AnActionEvent e) { enableActions(false); @@ -93,7 +97,10 @@ public class ResponsePanel extends JPanel { public void addDeleteAction(Runnable onDelete) { addIconActionButton(new IconActionButton( - new AnAction("Delete Response", "Delete response description", Actions.GC) { + new AnAction( + CodeGPTBundle.get("toolwindow.chat.response.action.deleteResponse.text"), + CodeGPTBundle.get("toolwindow.chat.response.action.deleteResponse.description"), + Actions.GC) { @Override public void actionPerformed(@NotNull AnActionEvent e) { onDelete.run(); @@ -109,7 +116,10 @@ public class ResponsePanel extends JPanel { } private JBLabel getIconLabel() { - return new JBLabel("CodeGPT", Icons.DefaultIcon, SwingConstants.LEADING) + return new JBLabel( + CodeGPTBundle.get("project.label"), + Icons.DefaultIcon, + SwingConstants.LEADING) .setAllowAutoWrapping(true) .withFont(JBFont.label().asBold()); } diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/components/UserPromptTextArea.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/components/UserPromptTextArea.java index 41d16951..bfa71aa4 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/components/UserPromptTextArea.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/components/UserPromptTextArea.java @@ -7,6 +7,7 @@ import com.intellij.ui.JBColor; import com.intellij.ui.components.JBTextArea; import com.intellij.util.ui.JBUI; import com.intellij.util.ui.UIUtil; +import ee.carlrobert.codegpt.CodeGPTBundle; import ee.carlrobert.codegpt.Icons; import ee.carlrobert.codegpt.completions.CompletionRequestHandler; import ee.carlrobert.codegpt.util.SwingUtils; @@ -55,7 +56,7 @@ public class UserPromptTextArea extends JPanel { textArea.setBackground(BACKGROUND_COLOR); textArea.setLineWrap(true); textArea.setWrapStyleWord(true); - textArea.getEmptyText().setText("Ask me anything"); + textArea.getEmptyText().setText(CodeGPTBundle.get("toolwindow.chat.textArea.emptyText")); textArea.setBorder(JBUI.Borders.empty(8, 4)); var input = textArea.getInputMap(); input.put(KeyStroke.getKeyStroke("ENTER"), TEXT_SUBMIT); @@ -171,6 +172,7 @@ public class UserPromptTextArea extends JPanel { } } + // TODO: IconActionButton? private JButton createIconButton(Icon icon, @Nullable Runnable submitListener) { var button = SwingUtils.createIconButton(icon); if (submitListener != null) { diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/contextual/ContextualChatToolWindowLandingPanel.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/contextual/ContextualChatToolWindowLandingPanel.java index 2960631b..f8892c52 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/contextual/ContextualChatToolWindowLandingPanel.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/contextual/ContextualChatToolWindowLandingPanel.java @@ -4,7 +4,6 @@ import static com.intellij.openapi.ui.DialogWrapper.OK_EXIT_CODE; import static ee.carlrobert.codegpt.util.ThemeUtils.getPanelBackgroundColor; import static javax.swing.event.HyperlinkEvent.EventType.ACTIVATED; -import com.intellij.openapi.application.ApplicationManager; import com.intellij.openapi.diagnostic.Logger; import com.intellij.openapi.options.ShowSettingsUtil; import com.intellij.openapi.project.Project; @@ -14,9 +13,6 @@ import ee.carlrobert.codegpt.indexes.CodebaseIndexingTask; import ee.carlrobert.codegpt.indexes.FolderStructureTreePanel; import ee.carlrobert.codegpt.settings.SettingsConfigurable; import ee.carlrobert.codegpt.toolwindow.chat.components.ResponsePanel; -import ee.carlrobert.codegpt.completions.you.YouUserManager; -import ee.carlrobert.codegpt.completions.you.auth.AuthenticationNotifier; -import ee.carlrobert.codegpt.completions.you.auth.SignedOutNotifier; import ee.carlrobert.codegpt.util.OverlayUtils; import ee.carlrobert.codegpt.util.SwingUtils; import ee.carlrobert.vector.VectorStore; @@ -25,6 +21,7 @@ import javax.swing.event.HyperlinkEvent; @FunctionalInterface interface ActionEvent { + void handleAction(String prompt); } @@ -43,43 +40,30 @@ class ContextualChatToolWindowLandingPanel extends ResponsePanel { .connect() .subscribe(CodebaseIndexingCompletedNotifier.INDEXING_COMPLETED_TOPIC, (CodebaseIndexingCompletedNotifier) () -> updateContent(createContent())); - - var messageBusConnection = ApplicationManager.getApplication().getMessageBus().connect(); - messageBusConnection.subscribe(AuthenticationNotifier.AUTHENTICATION_TOPIC, (AuthenticationNotifier) () -> updateContent(createContent())); - messageBusConnection.subscribe(SignedOutNotifier.SIGNED_OUT_TOPIC, (SignedOutNotifier) () -> updateContent(createContent())); } private JTextPane createContent() { var description = createTextPane(); - var userManager = YouUserManager.getInstance(); - - if (userManager.getAuthenticationResponse() == null) { - description.setText("" + - "

It looks like you haven't logged in. Please log in to use the feature.

" + - ""); - return description; - } - - if (!userManager.isSubscribed()) { - description.setText("" + - "

You are not currently subscribed to any plan.

" + - ""); - return description; - } - if (VectorStore.getInstance(CodeGPTPlugin.getPluginBasePath()).isIndexExists()) { description.setText("" + - "

Feel free to ask me anything about your codebase, and I'll be your helpful guide, dedicated to providing you with the best answers possible!

" + - "

Here are a few examples of how I might be helpful:

" + + "

Feel free to ask me anything about your codebase, and I'll be your helpful guide, dedicated to providing you with the best answers possible!

" + + + "

Here are a few examples of how I might be helpful:

" + + "
    " + - "
  • List all the dependencies that the project usesAre there any scheduled tasks or background jobs running in our codebase, and if so, what are they responsible for?
  • " + - "
  • Can you provide an overview of the authentication and authorization mechanism implemented in our application?
  • " + + "
  • List all the dependencies that the project usesAre there any scheduled tasks or background jobs running in our codebase, and if so, what are they responsible for?
  • " + + + "
  • Can you provide an overview of the authentication and authorization mechanism implemented in our application?
  • " + + ""); } else { description.setText("" + - "

    It looks like you haven't indexed your codebase yet.

    " + - "

    Start indexing your codebase to get access to contextual chat experience.

    " + + "

    It looks like you haven't indexed your codebase yet.

    " + + + "

    Start indexing your codebase to get access to contextual chat experience.

    " + + ""); } @@ -87,13 +71,8 @@ class ContextualChatToolWindowLandingPanel extends ResponsePanel { } private JTextPane createTextPane() { - var textPane = new JTextPane(); - textPane.addHyperlinkListener(this::handleHyperlinkClicked); + var textPane = SwingUtils.createTextPane(this::handleHyperlinkClicked); textPane.setBackground(getPanelBackgroundColor()); - textPane.setContentType("text/html"); - textPane.putClientProperty(JTextPane.HONOR_DISPLAY_PROPERTIES, true); - textPane.setFocusable(false); - textPane.setEditable(false); return textPane; } @@ -112,7 +91,8 @@ class ContextualChatToolWindowLandingPanel extends ResponsePanel { "Are there any scheduled tasks or background jobs running in our codebase, and if so, what are they responsible for?"); break; case "AUTHENTICATION_MECHANISM": - actionEvent.handleAction("Can you provide an overview of the authentication and authorization mechanism implemented in our application?"); + actionEvent.handleAction( + "Can you provide an overview of the authentication and authorization mechanism implemented in our application?"); break; case "START_INDEXING": var folderStructureTreePanel = new FolderStructureTreePanel(project); diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/contextual/ContextualChatToolWindowTabPanel.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/contextual/ContextualChatToolWindowTabPanel.java index 5fb8d5f1..cc88a321 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/contextual/ContextualChatToolWindowTabPanel.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/contextual/ContextualChatToolWindowTabPanel.java @@ -1,14 +1,9 @@ package ee.carlrobert.codegpt.toolwindow.chat.contextual; -import com.intellij.openapi.application.ApplicationManager; import com.intellij.openapi.project.Project; -import ee.carlrobert.codegpt.completions.you.YouUserManager; import ee.carlrobert.codegpt.conversations.Conversation; import ee.carlrobert.codegpt.conversations.message.Message; -import ee.carlrobert.codegpt.indexes.CodebaseIndexingCompletedNotifier; import ee.carlrobert.codegpt.toolwindow.chat.BaseChatToolWindowTabPanel; -import ee.carlrobert.codegpt.completions.you.auth.AuthenticationNotifier; -import ee.carlrobert.codegpt.completions.you.auth.SignedOutNotifier; import javax.swing.JComponent; import org.jetbrains.annotations.NotNull; @@ -17,19 +12,6 @@ public class ContextualChatToolWindowTabPanel extends BaseChatToolWindowTabPanel public ContextualChatToolWindowTabPanel(@NotNull Project project) { super(project, true); displayLandingView(); - userPromptTextArea.setTextAreaEnabled(YouUserManager.getInstance().isSubscribed()); - - project.getMessageBus() - .connect() - .subscribe(CodebaseIndexingCompletedNotifier.INDEXING_COMPLETED_TOPIC, - (CodebaseIndexingCompletedNotifier) () -> userPromptTextArea.setTextAreaEnabled( - YouUserManager.getInstance().isSubscribed())); - - var messageBusConnection = ApplicationManager.getApplication().getMessageBus().connect(); - messageBusConnection.subscribe(AuthenticationNotifier.AUTHENTICATION_TOPIC, - (AuthenticationNotifier) () -> userPromptTextArea.setTextAreaEnabled( - YouUserManager.getInstance().isSubscribed())); - messageBusConnection.subscribe(SignedOutNotifier.SIGNED_OUT_TOPIC, (SignedOutNotifier) () -> userPromptTextArea.setTextAreaEnabled(false)); } @Override diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/editor/actions/CopyAction.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/editor/actions/CopyAction.java index 46dea8f3..1578c83e 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/editor/actions/CopyAction.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/editor/actions/CopyAction.java @@ -33,6 +33,8 @@ public class CopyAction extends TrackableAction { var locationOnScreen = ((MouseEvent) event.getInputEvent()).getLocationOnScreen(); locationOnScreen.y = locationOnScreen.y - 16; - OverlayUtils.showInfoBalloon("Code copied!", locationOnScreen); + OverlayUtils.showInfoBalloon( + CodeGPTBundle.get("toolwindow.chat.editor.action.copy.success"), + locationOnScreen); } } diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/editor/actions/EditAction.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/editor/actions/EditAction.java index 36d58873..c524c75c 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/editor/actions/EditAction.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/editor/actions/EditAction.java @@ -34,7 +34,9 @@ public class EditAction extends TrackableAction { settings.setCaretRowShown(!viewer); event.getPresentation().setIcon(viewer ? Actions.EditSource : Actions.Show); - event.getPresentation().setText(viewer ? "Edit Source" : "Disable Editing"); + event.getPresentation().setText(viewer ? + CodeGPTBundle.get("toolwindow.chat.editor.action.edit.title") : + CodeGPTBundle.get("toolwindow.chat.editor.action.disableEditing.title")); var locationOnScreen = ((MouseEvent) event.getInputEvent()).getLocationOnScreen(); locationOnScreen.y = locationOnScreen.y - 16; diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/standard/StandardChatToolWindowLandingPanel.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/standard/StandardChatToolWindowLandingPanel.java index 80923b7f..b13c7fb3 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/standard/StandardChatToolWindowLandingPanel.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/standard/StandardChatToolWindowLandingPanel.java @@ -41,13 +41,8 @@ class StandardChatToolWindowLandingPanel extends ResponsePanel { } private JTextPane createTextPane() { - var textPane = new JTextPane(); - textPane.addHyperlinkListener(this::handleHyperlinkClicked); + var textPane = SwingUtils.createTextPane(this::handleHyperlinkClicked); textPane.setBackground(getPanelBackgroundColor()); - textPane.setContentType("text/html"); - textPane.putClientProperty(JTextPane.HONOR_DISPLAY_PROPERTIES, true); - textPane.setFocusable(false); - textPane.setEditable(false); return textPane; } diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/standard/StandardChatToolWindowPanel.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/standard/StandardChatToolWindowPanel.java index bc04ce9c..22dd8dd8 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/standard/StandardChatToolWindowPanel.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/standard/StandardChatToolWindowPanel.java @@ -3,8 +3,6 @@ package ee.carlrobert.codegpt.toolwindow.chat.standard; import com.intellij.openapi.Disposable; import com.intellij.openapi.actionSystem.ActionManager; import com.intellij.openapi.actionSystem.ActionToolbar; -import com.intellij.openapi.actionSystem.Constraints; -import com.intellij.openapi.actionSystem.DefaultActionGroup; import com.intellij.openapi.actionSystem.DefaultCompactActionGroup; import com.intellij.openapi.project.Project; import com.intellij.openapi.ui.SimpleToolWindowPanel; diff --git a/src/main/java/ee/carlrobert/codegpt/util/DownloadingUtils.java b/src/main/java/ee/carlrobert/codegpt/util/DownloadingUtils.java new file mode 100644 index 00000000..2d3992a0 --- /dev/null +++ b/src/main/java/ee/carlrobert/codegpt/util/DownloadingUtils.java @@ -0,0 +1,37 @@ +package ee.carlrobert.codegpt.util; + +import static java.lang.String.format; + +import ee.carlrobert.codegpt.util.file.FileUtils; + +public class DownloadingUtils { + + private static final int BYTES_IN_MB = 1024 * 1024; + + public static String getFormattedDownloadProgress(long startTime, long fileSize, long bytesRead) { + long timeElapsed = System.currentTimeMillis() - startTime; + + double speed = ((double) bytesRead / timeElapsed) * 1000 / BYTES_IN_MB; + double percent = (double) bytesRead / fileSize * 100; + double downloadedMB = (double) bytesRead / BYTES_IN_MB; + double totalMB = (double) fileSize / BYTES_IN_MB; + double remainingMB = totalMB - downloadedMB; + + return format( + "%s of %s (%.2f%%), Speed: %.2f MB/sec, Time left: %s", + FileUtils.convertFileSize((long) downloadedMB * BYTES_IN_MB), + FileUtils.convertFileSize((long) totalMB * BYTES_IN_MB), + percent, + speed, + getTimeLeftFormattedString(speed, remainingMB)); + } + + private static String getTimeLeftFormattedString(double speed, double remainingMB) { + double timeLeftSec = speed > 0 ? remainingMB / speed : 0; + long hours = (long) (timeLeftSec / 3600); + long minutes = (long) ((timeLeftSec % 3600) / 60); + long seconds = (long) (timeLeftSec % 60); + + return format("%02d:%02d:%02d", hours, minutes, seconds); + } +} diff --git a/src/main/java/ee/carlrobert/codegpt/util/EditorUtils.java b/src/main/java/ee/carlrobert/codegpt/util/EditorUtils.java index 0f05a152..b1e4e770 100644 --- a/src/main/java/ee/carlrobert/codegpt/util/EditorUtils.java +++ b/src/main/java/ee/carlrobert/codegpt/util/EditorUtils.java @@ -63,7 +63,11 @@ public final class EditorUtils { var editor = getSelectedEditor(project); if (editor != null) { var selectionModel = editor.getSelectionModel(); - editor.getDocument().replaceString(selectionModel.getSelectionStart(), selectionModel.getSelectionEnd(), text); + editor.getDocument() + .replaceString( + selectionModel.getSelectionStart(), + selectionModel.getSelectionEnd(), + text); editor.getContentComponent().requestFocus(); selectionModel.removeSelection(); } diff --git a/src/main/java/ee/carlrobert/codegpt/util/OverlayUtils.java b/src/main/java/ee/carlrobert/codegpt/util/OverlayUtils.java index c07b09c8..9b2e1125 100644 --- a/src/main/java/ee/carlrobert/codegpt/util/OverlayUtils.java +++ b/src/main/java/ee/carlrobert/codegpt/util/OverlayUtils.java @@ -17,6 +17,7 @@ import com.intellij.openapi.ui.MessageDialogBuilder; import com.intellij.openapi.ui.MessageType; import com.intellij.openapi.ui.Messages; import com.intellij.openapi.ui.popup.Balloon; +import com.intellij.openapi.ui.popup.Balloon.Position; import com.intellij.openapi.ui.popup.JBPopupFactory; import com.intellij.ui.awt.RelativePoint; import com.intellij.ui.components.JBLabel; @@ -27,6 +28,7 @@ import ee.carlrobert.codegpt.conversations.ConversationsState; import ee.carlrobert.codegpt.indexes.FolderStructureTreePanel; import java.awt.Point; import java.awt.event.MouseEvent; +import javax.swing.JComponent; import org.jetbrains.annotations.NotNull; public class OverlayUtils { @@ -117,4 +119,12 @@ public class OverlayUtils { .createBalloon() .show(RelativePoint.fromScreen(locationOnScreen), Balloon.Position.above); } + + public static void showBalloon(String content, MessageType messageType, JComponent component) { + JBPopupFactory.getInstance() + .createHtmlTextBalloonBuilder(content, messageType, null) + .setFadeoutTime(2500) + .createBalloon() + .show(RelativePoint.getSouthOf(component), Position.below); + } } diff --git a/src/main/java/ee/carlrobert/codegpt/util/SwingUtils.java b/src/main/java/ee/carlrobert/codegpt/util/SwingUtils.java index e0cd01a3..8cdc579d 100644 --- a/src/main/java/ee/carlrobert/codegpt/util/SwingUtils.java +++ b/src/main/java/ee/carlrobert/codegpt/util/SwingUtils.java @@ -4,26 +4,33 @@ import static javax.swing.event.HyperlinkEvent.EventType.ACTIVATED; import com.intellij.ide.BrowserUtil; import com.intellij.util.ui.UI; -import java.awt.Component; -import java.awt.Desktop; import java.awt.Dimension; import java.awt.event.ActionEvent; -import java.io.IOException; import java.net.URISyntaxException; import javax.swing.AbstractAction; import javax.swing.BorderFactory; -import javax.swing.Box; import javax.swing.Icon; import javax.swing.JButton; import javax.swing.JComponent; import javax.swing.JLabel; import javax.swing.JPanel; import javax.swing.JTextArea; +import javax.swing.JTextPane; import javax.swing.KeyStroke; import javax.swing.event.HyperlinkEvent; +import javax.swing.event.HyperlinkListener; public class SwingUtils { + public static JTextPane createTextPane(HyperlinkListener listener) { + var textPane = new JTextPane(); + textPane.putClientProperty(JTextPane.HONOR_DISPLAY_PROPERTIES, true); + textPane.addHyperlinkListener(listener); + textPane.setContentType("text/html"); + textPane.setEditable(false); + return textPane; + } + public static JButton createIconButton(Icon icon) { var button = new JButton(icon); button.setBorder(BorderFactory.createEmptyBorder()); @@ -32,13 +39,6 @@ public class SwingUtils { return button; } - public static Box justifyLeft(Component component) { - Box box = Box.createHorizontalBox(); - box.add(component); - box.add(Box.createHorizontalGlue()); - return box; - } - public static void setEqualLabelWidths(JPanel firstPanel, JPanel secondPanel) { var firstLabel = firstPanel.getComponents()[0]; var secondLabel = secondPanel.getComponents()[0]; @@ -47,10 +47,6 @@ public class SwingUtils { } } - public static JPanel createPanel(JComponent component, String label) { - return createPanel(component, label, false); - } - public static JPanel createPanel(JComponent component, String label, boolean resizeX) { return UI.PanelFactory.panel(component) .withLabel(label) diff --git a/src/main/java/ee/carlrobert/codegpt/util/file/FileUtils.java b/src/main/java/ee/carlrobert/codegpt/util/file/FileUtils.java index a276dba9..7fd95700 100644 --- a/src/main/java/ee/carlrobert/codegpt/util/file/FileUtils.java +++ b/src/main/java/ee/carlrobert/codegpt/util/file/FileUtils.java @@ -6,16 +6,23 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.intellij.openapi.diagnostic.Logger; import com.intellij.openapi.editor.Editor; import com.intellij.openapi.fileEditor.FileDocumentManager; +import com.intellij.openapi.progress.ProgressIndicator; import com.intellij.openapi.util.io.FileUtil; import com.intellij.openapi.vfs.VirtualFile; +import ee.carlrobert.codegpt.CodeGPTPlugin; import java.io.File; +import java.io.FileOutputStream; import java.io.IOException; import java.io.Writer; +import java.net.URL; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.nio.file.StandardOpenOption; +import java.text.DecimalFormat; import java.util.List; import java.util.Map; import java.util.Objects; @@ -40,6 +47,33 @@ public class FileUtils { } } + public static void copyFileWithProgress( + String fileName, + URL url, + long[] bytesRead, + long fileSize, + ProgressIndicator indicator) throws IOException { + FileUtils.tryCreateDirectory(CodeGPTPlugin.getLlamaModelsPath()); + + try ( + var readableByteChannel = Channels.newChannel(url.openStream()); + var fileOutputStream = new FileOutputStream( + CodeGPTPlugin.getLlamaModelsPath() + File.separator + fileName)) { + var buffer = ByteBuffer.allocateDirect(1024 * 10); + + while (readableByteChannel.read(buffer) != -1) { + if (indicator.isCanceled()) { + readableByteChannel.close(); + break; + } + buffer.flip(); + bytesRead[0] += fileOutputStream.getChannel().write(buffer); + buffer.clear(); + indicator.setFraction((double) bytesRead[0] / fileSize); + } + } + } + public static VirtualFile getEditorFile(@NotNull Editor editor) { return FileDocumentManager.getInstance().getFile(editor.getDocument()); } @@ -115,6 +149,19 @@ public class FileUtils { } } + public static String convertFileSize(long fileSizeInBytes) { + String[] units = {"B", "KB", "MB", "GB"}; + int unitIndex = 0; + double fileSize = fileSizeInBytes; + + while (fileSize >= 1024 && unitIndex < units.length - 1) { + fileSize /= 1024; + unitIndex++; + } + + return new DecimalFormat("#.##").format(fileSize) + " " + units[unitIndex]; + } + private static Optional> findFirstExtension( List languageFileExtensionMappings, String language) { diff --git a/src/main/resources/META-INF/plugin.xml b/src/main/resources/META-INF/plugin.xml index fddc2f3e..889141e7 100644 --- a/src/main/resources/META-INF/plugin.xml +++ b/src/main/resources/META-INF/plugin.xml @@ -23,6 +23,7 @@ + diff --git a/src/main/resources/icons/llama.svg b/src/main/resources/icons/llama.svg new file mode 100644 index 00000000..aec9c085 --- /dev/null +++ b/src/main/resources/icons/llama.svg @@ -0,0 +1,32 @@ + + + + + \ No newline at end of file diff --git a/src/main/resources/icons/llama_dark.svg b/src/main/resources/icons/llama_dark.svg new file mode 100644 index 00000000..50996c33 --- /dev/null +++ b/src/main/resources/icons/llama_dark.svg @@ -0,0 +1,32 @@ + + + + + \ No newline at end of file diff --git a/src/main/resources/icons/you.svg b/src/main/resources/icons/you.svg index 7b017a39..d4d48352 100644 --- a/src/main/resources/icons/you.svg +++ b/src/main/resources/icons/you.svg @@ -1,22 +1,15 @@ - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + diff --git a/src/main/resources/icons/you_dark.svg b/src/main/resources/icons/you_dark.svg new file mode 100644 index 00000000..7b017a39 --- /dev/null +++ b/src/main/resources/icons/you_dark.svg @@ -0,0 +1,22 @@ + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/main/resources/messages/codegpt.properties b/src/main/resources/messages/codegpt.properties index a52481f3..d7e0336b 100644 --- a/src/main/resources/messages/codegpt.properties +++ b/src/main/resources/messages/codegpt.properties @@ -1,34 +1,57 @@ -action.editor.group.EditorActionGroup=CodeGPT +project.label=CodeGPT notification.group.name=CodeGPT notification group settings.displayName=CodeGPT: Settings -settingsConfigurable.section.userAuthentication.title=Authentication (Optional) -settingsConfigurable.section.userAuthentication.signIn.label=Sign In -settingsConfigurable.section.userInformation.title=User Information -settingsConfigurable.section.integration.apiKeyField.label=API key: -settingsConfigurable.section.integration.apiKeyField.comment=You can find your Secret API key in your User settings. -settingsConfigurable.section.integration.displayNameFieldLabel=Display name: -settingsConfigurable.section.service.title=Service Preference -settingsConfigurable.section.service.useOpenAIServiceRadioButtonLabel=Use OpenAI API -settingsConfigurable.section.service.useAzureServiceRadioButtonLabel=Use Azure OpenAI API -settingsConfigurable.section.service.useYouServiceRadioButtonLabel=Use You.com API (Free) -settingsConfigurable.section.service.useActiveDirectoryAuthenticationCheckBoxLabel=Use Azure active directory authentication -settingsConfigurable.section.service.openai.organizationField.label=Organization: -settingsConfigurable.section.service.openai.organizationField.comment=Useful when you are part of multiple organizations optional -settingsConfigurable.section.service.azure.resourceNameField.label=Resource name: -settingsConfigurable.section.service.azure.resourceNameField.comment=The name of your Azure OpenAI Resource. -settingsConfigurable.section.service.azure.deploymentIdField.label=Deployment ID: -settingsConfigurable.section.service.azure.deploymentIdField.comment=The name of your model deployment. You're required to first deploy a model before you can make calls. -settingsConfigurable.section.service.azure.apiVersionField.label=API version: -settingsConfigurable.section.service.azure.apiVersionField.comment=The API version to use for this operation. This follows the YYYY-MM-DD format. -settingsConfigurable.section.service.custom.hostField.label=Custom host: -settingsConfigurable.section.service.custom.hostField.comment=Example: http://localhost:8080 -settingsConfigurable.section.model.title=Model Preference -settingsConfigurable.section.model.selectionFieldLabel=Model: -settingsConfigurable.section.model.useChatCompletionRadioButtonLabel=Use chat completion -codebaseIndexing.task.title=Indexing codebase -codebaseIndexing.task.completed.title=Indexing completed -codebaseIndexing.task.completed.description=Creating embeddings completed -codebaseIndexing.task.failed.title=Unable to create embeddings +settings.openaiQuotaExceeded=OpenAI quota exceeded. +settingsConfigurable.displayName.label=Display name: +settingsConfigurable.service.label=Service: +settingsConfigurable.service.openai.apiKey.comment=You can find your Secret API key in your User settings. +settingsConfigurable.service.openai.organization.label=Organization: +settingsConfigurable.section.openai.organization.comment=Useful when you are part of multiple organizations optional +settingsConfigurable.service.azure.resourceName.label=Resource name: +settingsConfigurable.service.azure.resourceName.comment=The name of your Azure OpenAI resource. +settingsConfigurable.service.azure.deploymentId.label=Deployment ID: +settingsConfigurable.service.azure.deploymentId.comment=The name of your model deployment. You're required to first deploy a model before you can make calls. +settingsConfigurable.service.azure.apiVersion.label=API version: +settingsConfigurable.service.azure.apiVersion.comment=The API version to use for this operation. This follows the YYYY-MM-DD format. +settingsConfigurable.service.azure.bearerToken.label=Bearer token: +settingsConfigurable.service.azure.useApiKeyAuth.label=Use API key authentication +settingsConfigurable.service.azure.useActiveDirectoryAuth.label=Use Active Directory authentication +settingsConfigurable.service.you.email.label=Email address: +settingsConfigurable.service.you.password.label=Password: +settingsConfigurable.service.you.signIn.label=Sign In +settingsConfigurable.service.you.signOut.label=Sign Out +settingsConfigurable.service.you.displayResults.label=Display web search results +settingsConfigurable.service.you.authentication.title=Authentication (Optional) +settingsConfigurable.service.you.userInformation.title=User Information +settingsConfigurable.service.you.chatPreferences.title=Chat Preferences +settingsConfigurable.service.llama.modelPreferences.title=Model Preferences +settingsConfigurable.service.llama.serverPreferences.title=Server Preferences +settingsConfigurable.service.llama.modelSize.label=Model size: +settingsConfigurable.service.llama.quantization.label=Quantization: +settingsConfigurable.service.llama.quantization.comment=Quantization is a technique to reduce the computational and memory costs of running inference. Learn more +settingsConfigurable.service.llama.promptTemplate.label=Prompt template: +settingsConfigurable.service.llama.useCustomModel.label=Use custom model +settingsConfigurable.service.llama.customModelPath.label=Model path: +settingsConfigurable.service.llama.customModelPath.comment=Only .gguf files are supported +settingsConfigurable.service.llama.downloadModelLink.label=Download Model +settingsConfigurable.service.llama.cancelDownloadLink.label=Cancel Downloading +settingsConfigurable.service.llama.linkToModel.label=Link to model +settingsConfigurable.service.llama.contextSize.label=Context size: +settingsConfigurable.service.llama.contextSize.comment=The size of the prompt context. LLaMA models were built with a context of 2048, which will provide better results for longer input/inference +settingsConfigurable.service.llama.port.label=Port: +settingsConfigurable.service.llama.startServer.label=Start server +settingsConfigurable.service.llama.stopServer.label=Stop server +settingsConfigurable.service.llama.progress.stoppingServer=Stopping a server... +settingsConfigurable.service.llama.progress.startingServer=Starting a server... +settingsConfigurable.service.llama.progress.downloadingModel.title=Downloading Model +settingsConfigurable.service.llama.progress.downloadingModelIndicator.text=Downloading %s... +settingsConfigurable.service.llama.overlay.modelNotDownloaded.text=Model is not downloaded +settingsConfigurable.shared.authentication.title=Authentication +settingsConfigurable.shared.requestConfiguration.title=Request Configuration +settingsConfigurable.shared.apiKey.label=API key: +settingsConfigurable.shared.baseHost.label=Base host: +settingsConfigurable.shared.path.label=Path: +settingsConfigurable.shared.model.label=Model: configurationConfigurable.displayName=CodeGPT: Configuration configurationConfigurable.table.title=Editor Actions configurationConfigurable.table.emptyText=No actions configured @@ -45,13 +68,17 @@ configurationConfigurable.section.assistant.temperatureField.comment=The value o configurationConfigurable.section.assistant.maxTokensField.label=Max completion tokens: configurationConfigurable.section.assistant.maxTokensField.comment=The maximum capacity for completion. Must be between 100 and 2000 advancedSettingsConfigurable.displayName=CodeGPT: Advanced Settings -advancedSettingsConfigurable.section.proxy.title=HTTP/SOCKS Proxy -advancedSettingsConfigurable.section.proxy.typeComboBoxField.label=Proxy: -advancedSettingsConfigurable.section.proxy.hostField.label=Host name: -advancedSettingsConfigurable.section.proxy.portField.label=Port: -advancedSettingsConfigurable.section.proxy.authCheckBoxField.label=Proxy authentication -advancedSettingsConfigurable.section.proxy.usernameField.label=Username: -advancedSettingsConfigurable.section.proxy.passwordField.label=Password: +advancedSettingsConfigurable.proxy.title=HTTP/SOCKS Proxy +advancedSettingsConfigurable.proxy.typeComboBoxField.label=Proxy: +advancedSettingsConfigurable.proxy.hostField.label=Host name: +advancedSettingsConfigurable.proxy.portField.label=Port: +advancedSettingsConfigurable.proxy.authCheckBoxField.label=Proxy authentication +advancedSettingsConfigurable.proxy.usernameField.label=Username: +advancedSettingsConfigurable.proxy.passwordField.label=Password: +advancedSettingsConfigurable.connectionSettings.title=Connection Settings +advancedSettingsConfigurable.connectionSettings.connectionTimeout.label=Connection timeout (s): +advancedSettingsConfigurable.connectionSettings.readTimeout.label=Read timeout (s): +codebaseIndexing.task.title=Indexing codebase dialog.deleteConversation.title=Delete Conversation dialog.deleteConversation.description=Are you sure you want to delete this conversation? dialog.tokenLimitExceeded.title=Token Limit Exceeded @@ -62,11 +89,30 @@ editor.diff.title=CodeGPT Diff editor.diff.local.content.title=CodeGPT suggested code toolwindow.chat.editor.action.copy.title=Copy toolwindow.chat.editor.action.copy.description=Copy generated code +toolwindow.chat.editor.action.copy.success=Code copied! toolwindow.chat.editor.action.diff.title=Diff toolwindow.chat.editor.action.diff.description=Diff editor code against the generated one toolwindow.chat.editor.action.edit.title=Edit Source +toolwindow.chat.editor.action.disableEditing.title=Disable Editing toolwindow.chat.editor.action.edit.description=Edit generated code toolwindow.chat.editor.action.newFile.title=New File toolwindow.chat.editor.action.newFile.description=Create new file from generated code toolwindow.chat.editor.action.replaceSelection.title=Replace Selection toolwindow.chat.editor.action.replaceSelection.description=Replace main editor selected code +toolwindow.chat.response.action.reloadResponse.text=Reload Response +toolwindow.chat.response.action.reloadResponse.description=Reload response description +toolwindow.chat.response.action.deleteResponse.text=Delete Response +toolwindow.chat.response.action.deleteResponse.description=Delete response description +toolwindow.chat.youProCheckBox.text=Use GPT-4 model +toolwindow.chat.youProCheckBox.enable=Turn on for complex queries +toolwindow.chat.youProCheckBox.disable=Turn off for faster responses +toolwindow.chat.youProCheckBox.notAllowed=Enable by subscribing to YouPro plan +toolwindow.chat.textArea.emptyText=Ask me anything +service.openai.title=OpenAI Service +service.azure.title=Azure Service +service.you.title=You.com Service (Free, Cloud) +service.llama.title=LLaMA C/C++ Port (Free, Local) +validation.error.fieldRequired=This field is required. +validation.error.invalidEmail=The email you entered is invalid. +validation.error.mustBeNumber=Value must be number. +validation.error.mustBeBetweenZeroAndOne=Value must be between 0 and 1. \ No newline at end of file diff --git a/src/test/java/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.java b/src/test/java/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.java index 65dc609c..d7427079 100644 --- a/src/test/java/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.java +++ b/src/test/java/ee/carlrobert/codegpt/completions/DefaultCompletionRequestHandlerTest.java @@ -1,6 +1,8 @@ package ee.carlrobert.codegpt.completions; import static ee.carlrobert.codegpt.completions.CompletionRequestProvider.COMPLETION_SYSTEM_PROMPT; +import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA; +import static ee.carlrobert.llm.client.util.JSONUtil.e; import static ee.carlrobert.llm.client.util.JSONUtil.jsonArray; import static ee.carlrobert.llm.client.util.JSONUtil.jsonMap; import static ee.carlrobert.llm.client.util.JSONUtil.jsonMapResponse; @@ -10,14 +12,17 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.awaitility.Awaitility.await; import com.intellij.testFramework.fixtures.BasePlatformTestCase; +import ee.carlrobert.codegpt.CodeGPTPlugin; import ee.carlrobert.codegpt.conversations.ConversationService; import ee.carlrobert.codegpt.conversations.message.Message; import ee.carlrobert.codegpt.credentials.AzureCredentialsManager; import ee.carlrobert.codegpt.credentials.OpenAICredentialsManager; import ee.carlrobert.codegpt.settings.configuration.ConfigurationState; import ee.carlrobert.codegpt.settings.state.AzureSettingsState; +import ee.carlrobert.codegpt.settings.state.LlamaSettingsState; import ee.carlrobert.codegpt.settings.state.OpenAISettingsState; import ee.carlrobert.codegpt.settings.state.SettingsState; +import ee.carlrobert.codegpt.settings.state.YouSettingsState; import ee.carlrobert.llm.client.http.LocalCallbackServer; import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange; import ee.carlrobert.llm.client.http.expectation.StreamExpectation; @@ -34,8 +39,11 @@ public class DefaultCompletionRequestHandlerTest extends BasePlatformTestCase { super.setUp(); AzureCredentialsManager.getInstance().setApiKey("TEST_API_KEY"); OpenAICredentialsManager.getInstance().setApiKey("TEST_API_KEY"); + // FIXME OpenAISettingsState.getInstance().setBaseHost("http://127.0.0.1:8000"); AzureSettingsState.getInstance().setBaseHost("http://127.0.0.1:8000"); + YouSettingsState.getInstance().setBaseHost("http://127.0.0.1:8000"); + LlamaSettingsState.getInstance().setServerPort(8000); ConfigurationState.getInstance().setSystemPrompt(""); server = new LocalCallbackServer(8000); } @@ -46,7 +54,7 @@ public class DefaultCompletionRequestHandlerTest extends BasePlatformTestCase { super.tearDown(); } - public void testChatCompletionCall() { + public void testOpenAIChatCompletionCall() { var message = new Message("TEST_PROMPT"); var conversation = ConversationService.getInstance().startConversation(); var requestHandler = new CompletionRequestHandler(); @@ -54,6 +62,8 @@ public class DefaultCompletionRequestHandlerTest extends BasePlatformTestCase { var settings = SettingsState.getInstance(); settings.setUseOpenAIService(true); settings.setUseAzureService(false); + settings.setUseYouService(false); + settings.setUseLlamaService(false); expectStreamRequest("/v1/chat/completions", request -> { assertThat(request.getMethod()).isEqualTo("POST"); assertThat(request.getHeaders().get(AUTHORIZATION).get(0)).isEqualTo("Bearer TEST_API_KEY"); @@ -84,6 +94,8 @@ public class DefaultCompletionRequestHandlerTest extends BasePlatformTestCase { var settings = SettingsState.getInstance(); settings.setUseOpenAIService(false); settings.setUseAzureService(true); + settings.setUseYouService(false); + settings.setUseLlamaService(false); var azureSettings = AzureSettingsState.getInstance(); azureSettings.setResourceName("TEST_RESOURCE_NAME"); azureSettings.setApiVersion("TEST_API_VERSION"); @@ -123,6 +135,97 @@ public class DefaultCompletionRequestHandlerTest extends BasePlatformTestCase { await().atMost(5, SECONDS).until(() -> "Hello!".equals(message.getResponse())); } + public void testYouChatCompletionCall() { + var message = new Message("TEST_PROMPT"); + var conversation = ConversationService.getInstance().startConversation(); + conversation.addMessage(new Message("Ping", "Pong")); + var requestHandler = new CompletionRequestHandler(); + requestHandler.addRequestCompletedListener(message::setResponse); + var settings = SettingsState.getInstance(); + settings.setUseOpenAIService(false); + settings.setUseAzureService(false); + settings.setUseYouService(true); + settings.setUseLlamaService(false); + expectStreamRequest("/api/streamingSearch", request -> { + assertThat(request.getMethod()).isEqualTo("GET"); + assertThat(request.getUri().getPath()).isEqualTo("/api/streamingSearch"); + assertThat(request.getUri().getQuery()).isEqualTo( + "q=TEST_PROMPT&" + + "page=1&" + + "cfr=CodeGPT&" + + "count=10&" + + "safeSearch=WebPages,Translations,TimeZone,Computation,RelatedSearches&" + + "domain=youchat&" + + "chat=[{\"question\":\"Ping\",\"answer\":\"Pong\"}]&" + + "utm_source=ide&" + + "utm_medium=jetbrains&" + + "utm_campaign=" + CodeGPTPlugin.getVersion() + "&" + + "utm_content=CodeGPT"); + assertThat(request.getHeaders()) + .flatExtracting("Host", "Accept", "Connection", "User-agent", "Cookie") + .containsExactly("127.0.0.1:8000", + "text/event-stream", + "Keep-Alive", + "youide CodeGPT", + "safesearch_guest=Moderate; " + + "youpro_subscription=true; " + + "you_subscription=free; " + + "stytch_session=; " + + "ydc_stytch_session=; " + + "stytch_session_jwt=; " + + "ydc_stytch_session_jwt=; " + + "eg4=false; " + + "safesearch_9015f218b47611b62bbbaf61125cd2dac629e65c3d6f47573a2ec0e9b615c691=Moderate; " + + + "__cf_bm=aN2b3pQMH8XADeMB7bg9s1bJ_bfXBcCHophfOGRg6g0-1693601599-0-AWIt5Mr4Y3xQI4mIJ1lSf4+vijWKDobrty8OopDeBxY+NABe0MRFidF3dCUoWjRt8SVMvBZPI3zkOgcRs7Mz3yazd7f7c58HwW5Xg9jdBjNg;"); + return List.of( + jsonMapResponse("youChatToken", "Hel"), + jsonMapResponse("youChatToken", "lo"), + jsonMapResponse("youChatToken", "!")); + }); + + requestHandler.call(conversation, message, false); + + await().atMost(5, SECONDS).until(() -> "Hello!".equals(message.getResponse())); + } + + public void testLlamaChatCompletionCall() { + var message = new Message("TEST_PROMPT"); + var conversation = ConversationService.getInstance().startConversation(); + conversation.addMessage(new Message("Ping", "Pong")); + var requestHandler = new CompletionRequestHandler(); + requestHandler.addRequestCompletedListener(message::setResponse); + var settings = SettingsState.getInstance(); + settings.setUseOpenAIService(false); + settings.setUseAzureService(false); + settings.setUseYouService(false); + settings.setUseLlamaService(true); + expectStreamRequest("/completion", request -> { + assertThat(request.getBody()) + .extracting( + "prompt", + "n_predict", + "stream") + .containsExactly( + LLAMA.buildPrompt( + COMPLETION_SYSTEM_PROMPT, + "TEST_PROMPT", + conversation.getMessages()), + 512, + true); + return List.of( + jsonMapResponse("content", "Hel"), + jsonMapResponse("content", "lo!"), + jsonMapResponse( + e("content", ""), + e("stop", true))); + }); + + requestHandler.call(conversation, message, false); + + await().atMost(5, SECONDS).until(() -> "Hello!".equals(message.getResponse())); + } + private void expectStreamRequest(String path, StreamHttpExchange exchange) { server.addExpectation(new StreamExpectation(path, exchange)); } diff --git a/src/test/java/ee/carlrobert/codegpt/completions/PromptTemplateTest.java b/src/test/java/ee/carlrobert/codegpt/completions/PromptTemplateTest.java new file mode 100644 index 00000000..3f1edbe4 --- /dev/null +++ b/src/test/java/ee/carlrobert/codegpt/completions/PromptTemplateTest.java @@ -0,0 +1,141 @@ +package ee.carlrobert.codegpt.completions; + +import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.ALPACA; +import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.CHAT_ML; +import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA; +import static ee.carlrobert.codegpt.completions.llama.PromptTemplate.TORA; +import static org.assertj.core.api.Assertions.assertThat; + +import ee.carlrobert.codegpt.conversations.message.Message; +import java.util.List; +import org.junit.Test; + +public class PromptTemplateTest { + + private static final String SYSTEM_PROMPT = "TEST_SYSTEM_PROMPT"; + private static final String USER_PROMPT = "TEST_USER_PROMPT"; + private static final List HISTORY = List.of( + new Message("TEST_PREV_PROMPT_1", "TEST_PREV_RESPONSE_1"), + new Message("TEST_PREV_PROMPT_2", "TEST_PREV_RESPONSE_2")); + + @Test + public void shouldBuildLlamaPromptWithHistory() { + var prompt = LLAMA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY); + + assertThat(prompt).isEqualTo( + "<>TEST_SYSTEM_PROMPT<>\n" + + "[INST]TEST_PREV_PROMPT_1[/INST]\n" + + "TEST_PREV_RESPONSE_1\n" + + "[INST]TEST_PREV_PROMPT_2[/INST]\n" + + "TEST_PREV_RESPONSE_2\n" + + "[INST]TEST_USER_PROMPT[/INST]"); + } + + @Test + public void shouldBuildLlamaPromptWithoutHistory() { + var prompt = LLAMA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, List.of()); + + assertThat(prompt).isEqualTo( + "<>TEST_SYSTEM_PROMPT<>\n" + + "[INST]TEST_USER_PROMPT[/INST]"); + } + + @Test + public void shouldBuildAlpacaPromptWithHistory() { + var prompt = ALPACA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY); + + assertThat(prompt).isEqualTo( + "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n" + + "\n" + + "### Instruction\n" + + "TEST_PREV_PROMPT_1\n" + + "\n" + + "### Response:\n" + + "TEST_PREV_RESPONSE_1\n" + + "\n" + + "### Instruction\n" + + "TEST_PREV_PROMPT_2\n" + + "\n" + + "### Response:\n" + + "TEST_PREV_RESPONSE_2\n" + + "\n" + + "### Instruction\n" + + "TEST_USER_PROMPT\n" + + "\n" + + "### Response:\n"); + } + + @Test + public void shouldBuildAlpacaPromptWithoutHistory() { + var prompt = ALPACA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, List.of()); + + assertThat(prompt).isEqualTo( + "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n" + + "\n" + + "### Instruction\n" + + "TEST_USER_PROMPT\n" + + "\n" + + "### Response:\n"); + } + + @Test + public void shouldBuildChatMLPromptWithHistory() { + var prompt = CHAT_ML.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY); + + assertThat(prompt).isEqualTo( + "<|im_start|>system\n" + + "TEST_SYSTEM_PROMPT<|im_end|>\n" + + "<|im_start|>user\n" + + "TEST_PREV_PROMPT_1<|im_end|>\n" + + "<|im_start|>assistant\n" + + "TEST_PREV_RESPONSE_1<|im_end|>\n" + + "<|im_start|>user\n" + + "TEST_PREV_PROMPT_2<|im_end|>\n" + + "<|im_start|>assistant\n" + + "TEST_PREV_RESPONSE_2<|im_end|>\n" + + "<|im_start|>user\n" + + "TEST_USER_PROMPT<|im_end|>" + ); + } + + @Test + public void shouldBuildChatMLPromptWithoutHistory() { + var prompt = CHAT_ML.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, List.of()); + + assertThat(prompt).isEqualTo( + "<|im_start|>system\n" + + "TEST_SYSTEM_PROMPT<|im_end|>\n" + + "<|im_start|>user\n" + + "TEST_USER_PROMPT<|im_end|>"); + } + + @Test + public void shouldBuildToRAPromptWithHistory() { + var prompt = TORA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY); + + assertThat(prompt).isEqualTo( + "<|user|>\n" + + "TEST_PREV_PROMPT_1\n" + + "<|assistant|>\n" + + "TEST_PREV_RESPONSE_1\n" + + "<|user|>\n" + + "TEST_PREV_PROMPT_2\n" + + "<|assistant|>\n" + + "TEST_PREV_RESPONSE_2\n" + + "<|user|>\n" + + "TEST_USER_PROMPT\n" + + "<|assistant|>" + ); + } + + @Test + public void shouldBuildToRAPromptWithoutHistory() { + var prompt = TORA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, List.of()); + + assertThat(prompt).isEqualTo( + "<|user|>\n" + + "TEST_USER_PROMPT\n" + + "<|assistant|>" + ); + } +}