Merge remote-tracking branch 'origin/master' into platform/2024.1

This commit is contained in:
Carl-Robert Linnupuu 2024-04-22 11:49:37 +03:00
commit e7ef58ad3d
76 changed files with 2199 additions and 1319 deletions

@ -1 +1 @@
Subproject commit 594fca3fefe27b8e95cfb1656eb0e160ad15a793
Subproject commit 7dbdba5690ca61b3ee8c92cfac8e7e251042e787

View file

@ -28,6 +28,7 @@ import ee.carlrobert.codegpt.EncodingManager;
import ee.carlrobert.codegpt.Icons;
import ee.carlrobert.codegpt.completions.CompletionRequestService;
import ee.carlrobert.codegpt.settings.GeneralSettings;
import ee.carlrobert.codegpt.settings.configuration.CommitMessageTemplate;
import ee.carlrobert.codegpt.ui.OverlayUtil;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import ee.carlrobert.llm.completion.CompletionEventListener;
@ -94,7 +95,10 @@ public class GenerateGitCommitMessageAction extends AnAction {
if (editor != null) {
((EditorEx) editor).setCaretVisible(false);
CompletionRequestService.getInstance()
.generateCommitMessageAsync(gitDiff, getEventListener(project, editor.getDocument()));
.generateCommitMessageAsync(
project.getService(CommitMessageTemplate.class).getSystemPrompt(),
gitDiff,
getEventListener(project, editor.getDocument()));
}
}

View file

@ -49,7 +49,11 @@ public class IncludeFilesInContextAction extends AnAction {
private static final Logger LOG = Logger.getInstance(IncludeFilesInContextAction.class);
public IncludeFilesInContextAction() {
super(CodeGPTBundle.get("action.includeFilesInContext.title"));
this("action.includeFilesInContext.title");
}
public IncludeFilesInContextAction(String customTitleKey) {
super(CodeGPTBundle.get(customTitleKey));
}
@Override
@ -93,11 +97,6 @@ public class IncludeFilesInContextAction extends AnAction {
}
private @Nullable FileCheckboxTree getCheckboxTree(DataContext dataContext) {
var psiElement = CommonDataKeys.PSI_ELEMENT.getData(dataContext);
if (psiElement != null) {
return new PsiElementCheckboxTree(psiElement);
}
var selectedVirtualFiles = VIRTUAL_FILE_ARRAY.getData(dataContext);
if (selectedVirtualFiles != null) {
return new VirtualFileCheckboxTree(selectedVirtualFiles);

View file

@ -11,6 +11,7 @@ import com.intellij.openapi.extensions.PluginId;
import com.intellij.openapi.project.Project;
import ee.carlrobert.codegpt.CodeGPTKeys;
import ee.carlrobert.codegpt.ReferencedFile;
import ee.carlrobert.codegpt.actions.IncludeFilesInContextAction;
import ee.carlrobert.codegpt.conversations.message.Message;
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings;
import ee.carlrobert.codegpt.toolwindow.chat.ChatToolWindowContentManager;
@ -75,6 +76,8 @@ public class EditorActionsUtil {
};
group.add(action);
});
group.addSeparator();
group.add(new IncludeFilesInContextAction("action.includeFileInContext.title"));
}
}

View file

@ -1,10 +1,10 @@
package ee.carlrobert.codegpt.completions;
import static ee.carlrobert.codegpt.completions.ConversationType.DEFAULT;
import static ee.carlrobert.codegpt.completions.ConversationType.FIX_COMPILE_ERRORS;
import static ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.CUSTOM_SERVICE_API_KEY;
import static ee.carlrobert.codegpt.util.file.FileUtil.getResourceContent;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;
@ -24,8 +24,9 @@ import ee.carlrobert.codegpt.settings.IncludedFilesSettings;
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings;
import ee.carlrobert.codegpt.settings.service.ServiceType;
import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings;
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceChatCompletionSettingsState;
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings;
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettingsState;
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceState;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
import ee.carlrobert.codegpt.settings.service.you.YouSettings;
@ -59,7 +60,6 @@ import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@ -77,8 +77,6 @@ public class CompletionRequestProvider {
public static final String FIX_COMPILE_ERRORS_SYSTEM_PROMPT = getResourceContent(
"/prompts/fix-compile-errors.txt");
private static final Set<ConversationType> OPENAI_SYSTEM_CONVERSATION_TYPES = Set.of(
DEFAULT, FIX_COMPILE_ERRORS);
private final EncodingManager encodingManager = EncodingManager.getInstance();
private final Conversation conversation;
@ -118,16 +116,27 @@ public class CompletionRequestProvider {
public static Request buildCustomOpenAICompletionRequest(String system, String context) {
return buildCustomOpenAIChatCompletionRequest(
CustomServiceSettings.getCurrentState(),
ApplicationManager.getApplication().getService(CustomServiceState.class)
.getChatCompletionSettings(),
List.of(
new OpenAIChatCompletionStandardMessage("system", system),
new OpenAIChatCompletionStandardMessage("user", context)),
true);
}
public static Request buildCustomOpenAICompletionRequest(String input) {
return buildCustomOpenAIChatCompletionRequest(
ApplicationManager.getApplication().getService(CustomServiceSettings.class)
.getState()
.getChatCompletionSettings(),
List.of(new OpenAIChatCompletionStandardMessage("user", input)),
true);
}
public static Request buildCustomOpenAILookupCompletionRequest(String context) {
return buildCustomOpenAIChatCompletionRequest(
CustomServiceSettings.getCurrentState(),
ApplicationManager.getApplication().getService(CustomServiceState.class)
.getChatCompletionSettings(),
List.of(
new OpenAIChatCompletionStandardMessage(
"system",
@ -157,7 +166,7 @@ public class CompletionRequestProvider {
}
var systemPrompt = conversationType == FIX_COMPILE_ERRORS
? FIX_COMPILE_ERRORS_SYSTEM_PROMPT : ConfigurationSettings.getSystemPrompt();
? FIX_COMPILE_ERRORS_SYSTEM_PROMPT : ConfigurationSettings.getSystemPrompt();
var prompt = promptTemplate.buildPrompt(
systemPrompt,
@ -171,6 +180,7 @@ public class CompletionRequestProvider {
.setTop_p(settings.getTopP())
.setMin_p(settings.getMinP())
.setRepeat_penalty(settings.getRepeatPenalty())
.setStop(promptTemplate.getStopTokens())
.build();
}
@ -203,21 +213,21 @@ public class CompletionRequestProvider {
}
public Request buildCustomOpenAIChatCompletionRequest(
CustomServiceSettingsState customConfiguration,
CustomServiceChatCompletionSettingsState settings,
CallParameters callParameters) {
return buildCustomOpenAIChatCompletionRequest(
customConfiguration,
settings,
buildMessages(callParameters),
true);
}
private static Request buildCustomOpenAIChatCompletionRequest(
CustomServiceSettingsState customConfiguration,
CustomServiceChatCompletionSettingsState settings,
List<OpenAIChatCompletionMessage> messages,
boolean streamRequest) {
var requestBuilder = new Request.Builder().url(customConfiguration.getUrl().trim());
var requestBuilder = new Request.Builder().url(requireNonNull(settings.getUrl()).trim());
var credential = CredentialsStore.INSTANCE.getCredential(CUSTOM_SERVICE_API_KEY);
for (var entry : customConfiguration.getHeaders().entrySet()) {
for (var entry : settings.getHeaders().entrySet()) {
String value = entry.getValue();
if (credential != null && value.contains("$CUSTOM_SERVICE_API_KEY")) {
value = value.replace("$CUSTOM_SERVICE_API_KEY", credential);
@ -225,7 +235,7 @@ public class CompletionRequestProvider {
requestBuilder.addHeader(entry.getKey(), value);
}
var body = customConfiguration.getBody().entrySet().stream()
var body = settings.getBody().entrySet().stream()
.collect(Collectors.toMap(
Map.Entry::getKey,
entry -> {
@ -287,10 +297,13 @@ public class CompletionRequestProvider {
private List<OpenAIChatCompletionMessage> buildMessages(CallParameters callParameters) {
var message = callParameters.getMessage();
var messages = new ArrayList<OpenAIChatCompletionMessage>();
if (OPENAI_SYSTEM_CONVERSATION_TYPES.contains(callParameters.getConversationType())) {
String content = DEFAULT == callParameters.getConversationType()
? ConfigurationSettings.getSystemPrompt() : FIX_COMPILE_ERRORS_SYSTEM_PROMPT;
messages.add(new OpenAIChatCompletionStandardMessage("system", content));
if (callParameters.getConversationType() == ConversationType.DEFAULT) {
String systemPrompt = ConfigurationSettings.getCurrentState().getSystemPrompt();
messages.add(new OpenAIChatCompletionStandardMessage("system", systemPrompt));
}
if (callParameters.getConversationType() == ConversationType.FIX_COMPILE_ERRORS) {
messages.add(
new OpenAIChatCompletionStandardMessage("system", FIX_COMPILE_ERRORS_SYSTEM_PROMPT));
}
for (var prevMessage : conversation.getMessages()) {

View file

@ -29,6 +29,7 @@ import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionRequest;
import ee.carlrobert.llm.client.anthropic.completion.ClaudeCompletionStandardMessage;
import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest;
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionEventSourceListener;
import ee.carlrobert.llm.client.openai.completion.OpenAITextCompletionEventSourceListener;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionStandardMessage;
import ee.carlrobert.llm.client.openai.completion.response.OpenAIChatCompletionResponse;
@ -55,6 +56,15 @@ public final class CompletionRequestService {
return ApplicationManager.getApplication().getService(CompletionRequestService.class);
}
public EventSource getCustomOpenAICompletionAsync(
Request customRequest,
CompletionEventListener<String> eventListener) {
var httpClient = CompletionClientProvider.getDefaultClientBuilder().build();
return EventSources.createFactory(httpClient).newEventSource(
customRequest,
new OpenAITextCompletionEventSourceListener(eventListener));
}
public EventSource getCustomOpenAIChatCompletionAsync(
Request customRequest,
CompletionEventListener<String> eventListener) {
@ -76,7 +86,10 @@ public final class CompletionRequestService {
eventListener);
case CUSTOM_OPENAI -> getCustomOpenAIChatCompletionAsync(
requestProvider.buildCustomOpenAIChatCompletionRequest(
CustomServiceSettings.getCurrentState(),
ApplicationManager.getApplication()
.getService(CustomServiceSettings.class)
.getState()
.getChatCompletionSettings(),
callParameters),
eventListener);
case ANTHROPIC -> CompletionClientProvider.getClaudeClient().getCompletionAsync(
@ -93,21 +106,24 @@ public final class CompletionRequestService {
callParameters.getMessage(),
callParameters.getConversationType()),
eventListener);
default -> throw new IllegalArgumentException();
};
}
public EventSource getCodeCompletionAsync(
InfillRequestDetails requestDetails,
CompletionEventListener<String> eventListener) {
var httpClient = CompletionClientProvider.getDefaultClientBuilder().build();
return switch (GeneralSettings.getCurrentState().getSelectedService()) {
case OPENAI -> CompletionClientProvider.getOpenAIClient()
.getCompletionAsync(
CodeCompletionRequestFactory.INSTANCE.buildOpenAIRequest(requestDetails),
CodeCompletionRequestFactory.buildOpenAIRequest(requestDetails),
eventListener);
case CUSTOM_OPENAI -> EventSources.createFactory(httpClient).newEventSource(
CodeCompletionRequestFactory.buildCustomRequest(requestDetails),
new OpenAITextCompletionEventSourceListener(eventListener));
case LLAMA_CPP -> CompletionClientProvider.getLlamaClient()
.getChatCompletionAsync(
CodeCompletionRequestFactory.INSTANCE.buildLlamaRequest(requestDetails),
CodeCompletionRequestFactory.buildLlamaRequest(requestDetails),
eventListener);
default ->
throw new IllegalArgumentException("Code completion not supported for selected service");
@ -115,13 +131,13 @@ public final class CompletionRequestService {
}
public void generateCommitMessageAsync(
String prompt,
String systemPrompt,
String gitDiff,
CompletionEventListener<String> eventListener) {
var configuration = ConfigurationSettings.getCurrentState();
var commitMessagePrompt = configuration.getCommitMessagePrompt();
var openaiRequest = new OpenAIChatCompletionRequest.Builder(List.of(
new OpenAIChatCompletionStandardMessage("system", commitMessagePrompt),
new OpenAIChatCompletionStandardMessage("user", prompt)))
new OpenAIChatCompletionStandardMessage("system", systemPrompt),
new OpenAIChatCompletionStandardMessage("user", gitDiff)))
.setModel(OpenAISettings.getCurrentState().getModel())
.build();
var selectedService = GeneralSettings.getCurrentState().getSelectedService();
@ -134,18 +150,18 @@ public final class CompletionRequestService {
var httpClient = CompletionClientProvider.getDefaultClientBuilder().build();
EventSources.createFactory(httpClient).newEventSource(
CompletionRequestProvider.buildCustomOpenAICompletionRequest(
commitMessagePrompt,
prompt),
systemPrompt,
gitDiff),
new OpenAIChatCompletionEventSourceListener(eventListener));
break;
case ANTHROPIC:
var anthropicSettings = AnthropicSettings.getCurrentState();
var claudeRequest = new ClaudeCompletionRequest();
claudeRequest.setSystem(commitMessagePrompt);
claudeRequest.setSystem(systemPrompt);
claudeRequest.setStream(true);
claudeRequest.setMaxTokens(configuration.getMaxTokens());
claudeRequest.setModel(anthropicSettings.getModel());
claudeRequest.setMessages(List.of(new ClaudeCompletionStandardMessage("user", prompt)));
claudeRequest.setMessages(List.of(new ClaudeCompletionStandardMessage("user", gitDiff)));
CompletionClientProvider.getClaudeClient()
.getCompletionAsync(claudeRequest, eventListener);
break;
@ -164,7 +180,7 @@ public final class CompletionRequestService {
} else {
promptTemplate = settings.getRemoteModelPromptTemplate();
}
var finalPrompt = promptTemplate.buildPrompt(commitMessagePrompt, prompt, List.of());
var finalPrompt = promptTemplate.buildPrompt(systemPrompt, gitDiff, List.of());
CompletionClientProvider.getLlamaClient().getChatCompletionAsync(
new LlamaCompletionRequest.Builder(finalPrompt)
.setN_predict(configuration.getMaxTokens())

View file

@ -43,16 +43,31 @@ public enum HuggingFaceModel {
WIZARD_CODER_PYTHON_13B_Q5(13, 5, "WizardCoder-Python-13B-V1.0-GGUF"),
WIZARD_CODER_PYTHON_34B_Q3(34, 3, "WizardCoder-Python-34B-V1.0-GGUF"),
WIZARD_CODER_PYTHON_34B_Q4(34, 4, "WizardCoder-Python-34B-V1.0-GGUF"),
WIZARD_CODER_PYTHON_34B_Q5(34, 5, "WizardCoder-Python-34B-V1.0-GGUF");
WIZARD_CODER_PYTHON_34B_Q5(34, 5, "WizardCoder-Python-34B-V1.0-GGUF"),
LLAMA_3_8B_IQ3_M(8, 3, "Meta-Llama-3-8B-Instruct-IQ3_M.gguf", "lmstudio-community"),
LLAMA_3_8B_Q4_K_M(8, 4, "Meta-Llama-3-8B-Instruct-Q4_K_M.gguf", "lmstudio-community"),
LLAMA_3_8B_Q5_K_M(8, 5, "Meta-Llama-3-8B-Instruct-Q5_K_M.gguf", "lmstudio-community"),
LLAMA_3_8B_Q6_K(8, 6, "Meta-Llama-3-8B-Instruct-Q6_K.gguf", "lmstudio-community"),
LLAMA_3_8B_Q8_0(8, 8, "Meta-Llama-3-8B-Instruct-Q8_0.gguf", "lmstudio-community"),
LLAMA_3_70B_IQ1(70, 1, "Meta-Llama-3-70B-Instruct-IQ1_M.gguf", "lmstudio-community"),
LLAMA_3_70B_IQ2_XS(70, 2, "Meta-Llama-3-70B-Instruct-IQ2_XS.gguf", "lmstudio-community"),
LLAMA_3_70B_Q4_K_M(70, 4, "Meta-Llama-3-70B-Instruct-Q4_K_M.gguf", "lmstudio-community");
private final int parameterSize;
private final int quantization;
private final String modelName;
private final String user;
HuggingFaceModel(int parameterSize, int quantization, String modelName) {
this(parameterSize, quantization, modelName, "TheBloke");
}
HuggingFaceModel(int parameterSize, int quantization, String modelName, String user) {
this.parameterSize = parameterSize;
this.quantization = quantization;
this.modelName = modelName;
this.user = user;
}
public int getParameterSize() {
@ -68,13 +83,16 @@ public enum HuggingFaceModel {
}
public String getFileName() {
return modelName.toLowerCase().replace("-gguf", format(".Q%d_K_M.gguf", quantization));
if ("TheBloke".equals(user)) {
return modelName.toLowerCase().replace("-gguf", format(".Q%d_K_M.gguf", quantization));
}
return modelName;
}
public URL getFileURL() {
try {
return new URL(
format("https://huggingface.co/TheBloke/%s/resolve/main/%s", modelName, getFileName()));
"https://huggingface.co/%s/%s/resolve/main/%s".formatted(user, getDirectory(), getFileName()));
} catch (MalformedURLException ex) {
throw new RuntimeException(ex);
}
@ -82,12 +100,20 @@ public enum HuggingFaceModel {
public URL getHuggingFaceURL() {
try {
return new URL("https://huggingface.co/TheBloke/" + modelName);
return new URL("https://huggingface.co/%s/%s".formatted(user, getDirectory()));
} catch (MalformedURLException ex) {
throw new RuntimeException(ex);
}
}
private String getDirectory() {
if ("lmstudio-community".equals(user)) {
// Meta-Llama-3-8B-Instruct-Q4_K_M.gguf -> Meta-Llama-3-8B-Instruct-GGUF
return modelName.replaceFirst("-[^.-]+\\.gguf$", "-GGUF");
}
return modelName;
}
@Override
public String toString() {
return format("%d-bit precision", quantization);

View file

@ -82,7 +82,24 @@ public enum LlamaModel {
HuggingFaceModel.WIZARD_CODER_PYTHON_13B_Q5,
HuggingFaceModel.WIZARD_CODER_PYTHON_34B_Q3,
HuggingFaceModel.WIZARD_CODER_PYTHON_34B_Q4,
HuggingFaceModel.WIZARD_CODER_PYTHON_34B_Q5));
HuggingFaceModel.WIZARD_CODER_PYTHON_34B_Q5)),
LLAMA_3(
"Llama 3",
"Llama 3 is a family of large language models (LLMs), a collection of pretrained and "
+ "instruction tuned generative text models in 8 and 70B sizes. The Llama 3 instruction "
+ "tuned models are optimized for dialogue use cases and outperform many of the available"
+ " open source chat models on common industry benchmarks. Further, in developing these "
+ "models, we took great care to optimize helpfulness and safety.",
PromptTemplate.LLAMA_3,
List.of(
HuggingFaceModel.LLAMA_3_8B_IQ3_M,
HuggingFaceModel.LLAMA_3_8B_Q4_K_M,
HuggingFaceModel.LLAMA_3_8B_Q5_K_M,
HuggingFaceModel.LLAMA_3_8B_Q6_K,
HuggingFaceModel.LLAMA_3_8B_Q8_0,
HuggingFaceModel.LLAMA_3_70B_IQ1,
HuggingFaceModel.LLAMA_3_70B_IQ2_XS,
HuggingFaceModel.LLAMA_3_70B_Q4_K_M));
private final String label;
private final String description;

View file

@ -14,14 +14,17 @@ import com.intellij.openapi.Disposable;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.components.Service;
import com.intellij.openapi.diagnostic.Logger;
import com.intellij.openapi.ui.MessageType;
import com.intellij.openapi.util.Key;
import ee.carlrobert.codegpt.CodeGPTBundle;
import ee.carlrobert.codegpt.CodeGPTPlugin;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.llama.form.ServerProgressPanel;
import ee.carlrobert.codegpt.ui.OverlayUtil;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.Consumer;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
@ -32,65 +35,94 @@ public final class LlamaServerAgent implements Disposable {
private @Nullable OSProcessHandler makeProcessHandler;
private @Nullable OSProcessHandler startServerProcessHandler;
private ServerProgressPanel activeServerProgressPanel;
private boolean stoppedByUser;
public void startAgent(
LlamaServerStartupParams params,
ServerProgressPanel serverProgressPanel,
Runnable onSuccess,
Runnable onServerTerminated) {
Consumer<ServerProgressPanel> onServerTerminated) {
this.activeServerProgressPanel = serverProgressPanel;
ApplicationManager.getApplication().invokeLater(() -> {
try {
serverProgressPanel.updateText(
stoppedByUser = false;
serverProgressPanel.displayText(
CodeGPTBundle.get("llamaServerAgent.buildingProject.description"));
makeProcessHandler = new OSProcessHandler(getMakeCommandLinde());
makeProcessHandler = new OSProcessHandler(
getMakeCommandLine(params.additionalBuildParameters()));
makeProcessHandler.addProcessListener(
getMakeProcessListener(params, serverProgressPanel, onSuccess, onServerTerminated));
getMakeProcessListener(params, onSuccess, onServerTerminated));
makeProcessHandler.startNotify();
} catch (ExecutionException e) {
throw new RuntimeException(e);
showServerError(e.getMessage(), onServerTerminated);
}
});
}
public void stopAgent() {
stoppedByUser = true;
if (makeProcessHandler != null) {
makeProcessHandler.destroyProcess();
}
if (startServerProcessHandler != null) {
startServerProcessHandler.destroyProcess();
}
}
public boolean isServerRunning() {
return startServerProcessHandler != null
return (makeProcessHandler != null
&& makeProcessHandler.isStartNotified()
&& !makeProcessHandler.isProcessTerminated())
|| (startServerProcessHandler != null
&& startServerProcessHandler.isStartNotified()
&& !startServerProcessHandler.isProcessTerminated();
&& !startServerProcessHandler.isProcessTerminated());
}
private ProcessListener getMakeProcessListener(
LlamaServerStartupParams params,
ServerProgressPanel serverProgressPanel,
Runnable onSuccess,
Runnable onServerTerminated) {
Consumer<ServerProgressPanel> onServerTerminated) {
LOG.info("Building llama project");
return new ProcessAdapter() {
private final List<String> errorLines = new CopyOnWriteArrayList<>();
@Override
public void onTextAvailable(@NotNull ProcessEvent event, @NotNull Key outputType) {
if (ProcessOutputType.isStderr(outputType)) {
errorLines.add(event.getText());
return;
}
LOG.info(event.getText());
}
@Override
public void processTerminated(@NotNull ProcessEvent event) {
int exitCode = event.getExitCode();
LOG.info(format("Server build exited with code %d", exitCode));
if (stoppedByUser) {
onServerTerminated.accept(activeServerProgressPanel);
return;
}
if (exitCode != 0) {
showServerError(String.join(",", errorLines), onServerTerminated);
return;
}
try {
LOG.info("Booting up llama server");
serverProgressPanel.updateText(
activeServerProgressPanel.displayText(
CodeGPTBundle.get("llamaServerAgent.serverBootup.description"));
startServerProcessHandler = new OSProcessHandler.Silent(getServerCommandLine(params));
startServerProcessHandler.addProcessListener(
getProcessListener(params.port(), onSuccess, onServerTerminated));
getProcessListener(params.port(), onSuccess,
onServerTerminated));
startServerProcessHandler.startNotify();
} catch (ExecutionException ex) {
LOG.error("Unable to start llama server", ex);
throw new RuntimeException(ex);
showServerError(ex.getMessage(), onServerTerminated);
}
}
};
@ -99,27 +131,25 @@ public final class LlamaServerAgent implements Disposable {
private ProcessListener getProcessListener(
int port,
Runnable onSuccess,
Runnable onServerTerminated) {
Consumer<ServerProgressPanel> onServerTerminated) {
return new ProcessAdapter() {
private final ObjectMapper objectMapper = new ObjectMapper();
private final List<String> errorLines = new CopyOnWriteArrayList<>();
@Override
public void processTerminated(@NotNull ProcessEvent event) {
if (errorLines.isEmpty()) {
LOG.info(format("Server terminated with code %d", event.getExitCode()));
LOG.info(format("Server terminated with code %d", event.getExitCode()));
if (stoppedByUser) {
onServerTerminated.accept(activeServerProgressPanel);
} else {
LOG.info(String.join("", errorLines));
showServerError(String.join(",", errorLines), onServerTerminated);
}
onServerTerminated.run();
}
@Override
public void onTextAvailable(@NotNull ProcessEvent event, @NotNull Key outputType) {
if (ProcessOutputType.isStderr(outputType)) {
errorLines.add(event.getText());
return;
}
if (ProcessOutputType.isStdout(outputType)) {
@ -127,7 +157,8 @@ public final class LlamaServerAgent implements Disposable {
try {
var serverMessage = objectMapper.readValue(event.getText(), LlamaServerMessage.class);
if ("HTTP server listening".equals(serverMessage.message())) {
// hack
if ("HTTP server listening".equals(serverMessage.msg())) {
LOG.info("Server up and running!");
LlamaSettings.getCurrentState().setServerPort(port);
@ -141,11 +172,18 @@ public final class LlamaServerAgent implements Disposable {
};
}
private static GeneralCommandLine getMakeCommandLinde() {
private void showServerError(String errorText, Consumer<ServerProgressPanel> onServerTerminated) {
onServerTerminated.accept(activeServerProgressPanel);
LOG.info("Unable to start llama server:\n" + errorText);
OverlayUtil.showClosableBalloon(errorText, MessageType.ERROR, activeServerProgressPanel);
}
private static GeneralCommandLine getMakeCommandLine(List<String> additionalCompileParameters) {
GeneralCommandLine commandLine = new GeneralCommandLine().withCharset(StandardCharsets.UTF_8);
commandLine.setExePath("make");
commandLine.withWorkDirectory(CodeGPTPlugin.getLlamaSourcePath());
commandLine.addParameters("-j");
commandLine.addParameters(additionalCompileParameters);
commandLine.setRedirectErrorStream(false);
return commandLine;
}
@ -159,11 +197,16 @@ public final class LlamaServerAgent implements Disposable {
"-c", String.valueOf(params.contextLength()),
"--port", String.valueOf(params.port()),
"-t", String.valueOf(params.threads()));
commandLine.addParameters(params.additionalParameters());
commandLine.addParameters(params.additionalRunParameters());
commandLine.setRedirectErrorStream(false);
return commandLine;
}
public void setActiveServerProgressPanel(
ServerProgressPanel activeServerProgressPanel) {
this.activeServerProgressPanel = activeServerProgressPanel;
}
@Override
public void dispose() {
if (makeProcessHandler != null && !makeProcessHandler.isProcessTerminated()) {

View file

@ -3,5 +3,5 @@ package ee.carlrobert.codegpt.completions.llama;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
@JsonIgnoreProperties(ignoreUnknown = true)
public record LlamaServerMessage(String level, String message) {
public record LlamaServerMessage(String level, String msg) {
}

View file

@ -3,5 +3,6 @@ package ee.carlrobert.codegpt.completions.llama;
import java.util.List;
public record LlamaServerStartupParams(String modelPath, int contextLength, int threads, int port,
List<String> additionalParameters) {
List<String> additionalRunParameters,
List<String> additionalBuildParameters) {
}

View file

@ -1,5 +1,7 @@
package ee.carlrobert.codegpt.completions.llama;
import static java.util.Collections.emptyList;
import ee.carlrobert.codegpt.conversations.message.Message;
import java.util.List;
@ -55,6 +57,33 @@ public enum PromptTemplate {
.toString();
}
},
LLAMA_3("Llama 3", List.of("<|eot_id|>")) {
@Override
public String buildPrompt(String systemPrompt, String userPrompt, List<Message> history) {
var prompt = new StringBuilder("<|begin_of_text|>");
if (systemPrompt != null && !systemPrompt.isBlank()) {
prompt
.append("<|start_header_id|>system<|end_header_id|>\n\n")
.append(systemPrompt)
.append("<|eot_id|>");
}
for (var message : history) {
prompt
.append("<|start_header_id|>user<|end_header_id|>\n\n")
.append(message.getPrompt())
.append("<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n")
.append(message.getResponse())
.append("<|eot_id|>");
}
return prompt
.append("<|start_header_id|>user<|end_header_id|>\n\n")
.append(userPrompt)
.append("<|eot_id|><|start_header_id|>assistant<|end_header_id|>")
.toString();
}
},
MIXTRAL_INSTRUCT("Mixtral Instruct") {
@Override
public String buildPrompt(String systemPrompt, String userPrompt, List<Message> history) {
@ -102,10 +131,10 @@ public enum PromptTemplate {
StringBuilder prompt = new StringBuilder();
prompt.append("""
Below is an instruction that describes a task. \
Write a response that appropriately completes the request.
Below is an instruction that describes a task. \
Write a response that appropriately completes the request.
""");
""");
for (Message message : history) {
prompt.append("### Instruction\n")
@ -160,13 +189,23 @@ public enum PromptTemplate {
};
private final String label;
private final List<String> stopTokens;
PromptTemplate(String label) {
this(label, emptyList());
}
PromptTemplate(String label, List<String> stopTokens) {
this.label = label;
this.stopTokens = stopTokens;
}
public abstract String buildPrompt(String systemPrompt, String userPrompt, List<Message> history);
public List<String> getStopTokens() {
return stopTokens;
}
@Override
public String toString() {
return label;

View file

@ -51,6 +51,9 @@ public class GeneralSettings implements PersistentStateComponent<GeneralSettings
if ("azure.chat.completion".equals(clientCode)) {
state.setSelectedService(ServiceType.AZURE);
}
if ("custom.openai.chat.completion".equals(clientCode)) {
state.setSelectedService(ServiceType.CUSTOM_OPENAI);
}
if ("llama.chat.completion".equals(clientCode)) {
state.setSelectedService(ServiceType.LLAMA_CPP);
var llamaSettings = LlamaSettings.getCurrentState();

View file

@ -12,8 +12,18 @@ import com.intellij.openapi.ui.ComboBox;
import com.intellij.ui.components.JBTextField;
import com.intellij.util.ui.FormBuilder;
import ee.carlrobert.codegpt.CodeGPTBundle;
import ee.carlrobert.codegpt.settings.service.ServiceSelectionForm;
import ee.carlrobert.codegpt.settings.service.ServiceType;
import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings;
import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettingsForm;
import ee.carlrobert.codegpt.settings.service.azure.AzureSettings;
import ee.carlrobert.codegpt.settings.service.azure.AzureSettingsForm;
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceForm;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.llama.form.LlamaSettingsForm;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettingsForm;
import ee.carlrobert.codegpt.settings.service.you.YouSettings;
import ee.carlrobert.codegpt.settings.service.you.YouSettingsForm;
import java.awt.CardLayout;
import java.awt.Component;
import java.awt.Container;
@ -29,21 +39,30 @@ public class GeneralSettingsComponent {
private final JPanel mainPanel;
private final JBTextField displayNameField;
private final ComboBox<ServiceType> serviceComboBox;
private final ServiceSelectionForm serviceSelectionForm;
private final OpenAISettingsForm openAISettingsForm;
private final CustomServiceForm customConfigurationSettingsForm;
private final AnthropicSettingsForm anthropicSettingsForm;
private final AzureSettingsForm azureSettingsForm;
private final YouSettingsForm youSettingsForm;
private final LlamaSettingsForm llamaSettingsForm;
public GeneralSettingsComponent(Disposable parentDisposable, GeneralSettings settings) {
displayNameField = new JBTextField(settings.getState().getDisplayName(), 20);
serviceSelectionForm = new ServiceSelectionForm(parentDisposable);
openAISettingsForm = new OpenAISettingsForm(OpenAISettings.getCurrentState());
customConfigurationSettingsForm = new CustomServiceForm();
anthropicSettingsForm = new AnthropicSettingsForm(AnthropicSettings.getCurrentState());
azureSettingsForm = new AzureSettingsForm(AzureSettings.getCurrentState());
youSettingsForm = new YouSettingsForm(YouSettings.getCurrentState(), parentDisposable);
llamaSettingsForm = new LlamaSettingsForm(LlamaSettings.getCurrentState());
var cardLayout = new DynamicCardLayout();
var cards = new JPanel(cardLayout);
cards.add(serviceSelectionForm.getOpenAISettingsForm().getForm(), OPENAI.getCode());
cards.add(
serviceSelectionForm.getCustomConfigurationSettingsForm().getForm(),
CUSTOM_OPENAI.getCode());
cards.add(serviceSelectionForm.getAnthropicSettingsForm().getForm(), ANTHROPIC.getCode());
cards.add(serviceSelectionForm.getAzureSettingsForm().getForm(), AZURE.getCode());
cards.add(serviceSelectionForm.getYouSettingsForm(), YOU.getCode());
cards.add(serviceSelectionForm.getLlamaSettingsForm(), LLAMA_CPP.getCode());
cards.add(openAISettingsForm.getForm(), OPENAI.getCode());
cards.add(customConfigurationSettingsForm.getForm(), CUSTOM_OPENAI.getCode());
cards.add(anthropicSettingsForm.getForm(), ANTHROPIC.getCode());
cards.add(azureSettingsForm.getForm(), AZURE.getCode());
cards.add(youSettingsForm, YOU.getCode());
cards.add(llamaSettingsForm, LLAMA_CPP.getCode());
var serviceComboBoxModel = new DefaultComboBoxModel<ServiceType>();
serviceComboBoxModel.addAll(Arrays.stream(ServiceType.values()).toList());
serviceComboBox = new ComboBox<>(serviceComboBoxModel);
@ -63,6 +82,30 @@ public class GeneralSettingsComponent {
.getPanel();
}
public OpenAISettingsForm getOpenAISettingsForm() {
return openAISettingsForm;
}
public CustomServiceForm getCustomConfigurationSettingsForm() {
return customConfigurationSettingsForm;
}
public AnthropicSettingsForm getAnthropicSettingsForm() {
return anthropicSettingsForm;
}
public AzureSettingsForm getAzureSettingsForm() {
return azureSettingsForm;
}
public LlamaSettingsForm getLlamaSettingsForm() {
return llamaSettingsForm;
}
public YouSettingsForm getYouSettingsForm() {
return youSettingsForm;
}
public ServiceType getSelectedService() {
return serviceComboBox.getItem();
}
@ -79,10 +122,6 @@ public class GeneralSettingsComponent {
return displayNameField;
}
public ServiceSelectionForm getServiceSelectionForm() {
return serviceSelectionForm;
}
public String getDisplayName() {
return displayNameField.getText();
}
@ -91,6 +130,15 @@ public class GeneralSettingsComponent {
displayNameField.setText(displayName);
}
public void resetForms() {
openAISettingsForm.resetForm();
customConfigurationSettingsForm.resetForm();
anthropicSettingsForm.resetForm();
azureSettingsForm.resetForm();
youSettingsForm.resetForm();
llamaSettingsForm.resetForm();
}
static class DynamicCardLayout extends CardLayout {
@Override

View file

@ -18,7 +18,6 @@ import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettingsForm;
import ee.carlrobert.codegpt.settings.service.azure.AzureSettings;
import ee.carlrobert.codegpt.settings.service.azure.AzureSettingsForm;
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceForm;
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.llama.form.LlamaSettingsForm;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
@ -61,17 +60,15 @@ public class GeneralSettingsConfigurable implements Configurable {
@Override
public boolean isModified() {
var settings = GeneralSettings.getCurrentState();
var serviceSelectionForm = component.getServiceSelectionForm();
return !component.getDisplayName().equals(settings.getDisplayName())
|| component.getSelectedService() != settings.getSelectedService()
|| OpenAISettings.getInstance().isModified(serviceSelectionForm.getOpenAISettingsForm())
|| CustomServiceSettings.getInstance()
.isModified(serviceSelectionForm.getCustomConfigurationSettingsForm())
|| AnthropicSettings.getInstance()
.isModified(serviceSelectionForm.getAnthropicSettingsForm())
|| AzureSettings.getInstance().isModified(serviceSelectionForm.getAzureSettingsForm())
|| YouSettings.getInstance().isModified(serviceSelectionForm.getYouSettingsForm())
|| LlamaSettings.getInstance().isModified(serviceSelectionForm.getLlamaSettingsForm());
|| OpenAISettings.getInstance().isModified(component.getOpenAISettingsForm())
|| component.getCustomConfigurationSettingsForm().isModified()
|| AnthropicSettings.getInstance().isModified(component.getAnthropicSettingsForm())
|| AzureSettings.getInstance().isModified(component.getAzureSettingsForm())
|| YouSettings.getInstance().isModified(component.getYouSettingsForm())
|| LlamaSettings.getInstance().isModified(component.getLlamaSettingsForm());
}
@Override
@ -80,14 +77,13 @@ public class GeneralSettingsConfigurable implements Configurable {
settings.setDisplayName(component.getDisplayName());
settings.setSelectedService(component.getSelectedService());
var serviceSelectionForm = component.getServiceSelectionForm();
var openAISettingsForm = serviceSelectionForm.getOpenAISettingsForm();
var openAISettingsForm = component.getOpenAISettingsForm();
applyOpenAISettings(openAISettingsForm);
applyCustomOpenAISettings(serviceSelectionForm.getCustomConfigurationSettingsForm());
applyAnthropicSettings(serviceSelectionForm.getAnthropicSettingsForm());
applyAzureSettings(serviceSelectionForm.getAzureSettingsForm());
applyYouSettings(serviceSelectionForm.getYouSettingsForm());
applyLlamaSettings(serviceSelectionForm.getLlamaSettingsForm());
applyCustomOpenAISettings(component.getCustomConfigurationSettingsForm());
applyAnthropicSettings(component.getAnthropicSettingsForm());
applyAzureSettings(component.getAzureSettingsForm());
applyYouSettings(component.getYouSettingsForm());
applyLlamaSettings(component.getLlamaSettingsForm());
var serviceChanged = component.getSelectedService() != settings.getSelectedService();
var modelChanged = !OpenAISettings.getCurrentState().getModel()
@ -109,7 +105,7 @@ public class GeneralSettingsConfigurable implements Configurable {
private void applyCustomOpenAISettings(CustomServiceForm form) {
CredentialsStore.INSTANCE.setCredential(CUSTOM_SERVICE_API_KEY, form.getApiKey());
CustomServiceSettings.getInstance().loadState(form.getCurrentState());
form.applyChanges();
}
private void applyLlamaSettings(LlamaSettingsForm form) {
@ -142,7 +138,7 @@ public class GeneralSettingsConfigurable implements Configurable {
var settings = GeneralSettings.getCurrentState();
component.setDisplayName(settings.getDisplayName());
component.setSelectedService(settings.getSelectedService());
component.getServiceSelectionForm().resetForms();
component.resetForms();
}
@Override

View file

@ -16,7 +16,6 @@ import com.intellij.ui.TitledSeparator;
import com.intellij.ui.ToolbarDecorator;
import com.intellij.ui.components.JBCheckBox;
import com.intellij.ui.components.JBLabel;
import com.intellij.ui.components.JBTextArea;
import com.intellij.ui.components.JBTextField;
import com.intellij.ui.components.fields.IntegerField;
import com.intellij.ui.table.JBTable;
@ -93,7 +92,7 @@ public class ConfigurationComponent {
maxTokensField.setColumns(12);
maxTokensField.setValue(configuration.getMaxTokens());
systemPromptTextArea = new JTextArea();
systemPromptTextArea = new JTextArea(3, 60);
if (configuration.getSystemPrompt().isBlank()) {
// for backward compatibility
systemPromptTextArea.setText(COMPLETION_SYSTEM_PROMPT);
@ -101,13 +100,12 @@ public class ConfigurationComponent {
systemPromptTextArea.setText(configuration.getSystemPrompt());
}
systemPromptTextArea.setLineWrap(true);
systemPromptTextArea.setWrapStyleWord(true);
systemPromptTextArea.setBorder(JBUI.Borders.empty(8, 4));
systemPromptTextArea.setColumns(60);
systemPromptTextArea.setRows(3);
commitMessagePromptTextArea = new JBTextArea(configuration.getCommitMessagePrompt(),
3, 60);
commitMessagePromptTextArea = new JTextArea(configuration.getCommitMessagePrompt(), 3, 60);
commitMessagePromptTextArea.setLineWrap(true);
commitMessagePromptTextArea.setWrapStyleWord(true);
commitMessagePromptTextArea.setBorder(JBUI.Borders.empty(8, 4));
checkForPluginUpdatesCheckBox = new JBCheckBox(
@ -247,20 +245,19 @@ public class ConfigurationComponent {
}
private JPanel createCommitMessageConfigurationForm() {
var formBuilder = FormBuilder.createFormBuilder();
addAssistantFormLabeledComponent(
formBuilder,
"configurationConfigurable.section.commitMessage.systemPromptField.label",
"configurationConfigurable.section.commitMessage.systemPromptField.comment",
JBUI.Panels
.simplePanel(commitMessagePromptTextArea)
.withBorder(JBUI.Borders.customLine(
JBUI.CurrentTheme.CustomFrameDecorations.separatorForeground())));
formBuilder.addVerticalGap(8);
var form = formBuilder.getPanel();
form.setBorder(JBUI.Borders.emptyLeft(16));
return form;
return FormBuilder.createFormBuilder()
.setFormLeftIndent(16)
.addLabeledComponent(
new JBLabel(CodeGPTBundle.get(
"configurationConfigurable.section.commitMessage.systemPromptField.label"))
.withBorder(JBUI.Borders.emptyLeft(2)),
UI.PanelFactory.panel(commitMessagePromptTextArea)
.resizeX(false)
.withComment(CommitMessageTemplate.Companion.getHtmlDescription())
.createPanel(),
true
)
.getPanel();
}
private ComponentValidator createTemperatureInputValidator(

View file

@ -123,13 +123,13 @@ public class ConfigurationState {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
if (!(o instanceof ConfigurationState that)) {
return false;
}
ConfigurationState that = (ConfigurationState) o;
return maxTokens == that.maxTokens
&& Double.compare(that.temperature, temperature) == 0
&& Double.compare(temperature, that.temperature) == 0
&& checkForPluginUpdates == that.checkForPluginUpdates
&& checkForNewScreenshots == that.checkForNewScreenshots
&& createNewChatOnEachAction == that.createNewChatOnEachAction
&& ignoreGitCommitTokenLimit == that.ignoreGitCommitTokenLimit
&& methodNameGenerationEnabled == that.methodNameGenerationEnabled
@ -143,7 +143,8 @@ public class ConfigurationState {
@Override
public int hashCode() {
return Objects.hash(systemPrompt, commitMessagePrompt, maxTokens, temperature,
checkForPluginUpdates, createNewChatOnEachAction, ignoreGitCommitTokenLimit,
methodNameGenerationEnabled, captureCompileErrors, autoFormattingEnabled, tableData);
checkForPluginUpdates, checkForNewScreenshots, createNewChatOnEachAction,
ignoreGitCommitTokenLimit, methodNameGenerationEnabled, captureCompileErrors,
autoFormattingEnabled, tableData);
}
}

View file

@ -1,68 +0,0 @@
package ee.carlrobert.codegpt.settings.service;
import com.intellij.openapi.Disposable;
import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings;
import ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettingsForm;
import ee.carlrobert.codegpt.settings.service.azure.AzureSettings;
import ee.carlrobert.codegpt.settings.service.azure.AzureSettingsForm;
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceForm;
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.llama.form.LlamaSettingsForm;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettingsForm;
import ee.carlrobert.codegpt.settings.service.you.YouSettings;
import ee.carlrobert.codegpt.settings.service.you.YouSettingsForm;
/**
 * Facade that instantiates and exposes the per-provider settings forms
 * shown by the general settings panel.
 */
public class ServiceSelectionForm {

  private final OpenAISettingsForm openAIForm;
  private final CustomServiceForm customOpenAIForm;
  private final AnthropicSettingsForm anthropicForm;
  private final AzureSettingsForm azureForm;
  private final LlamaSettingsForm llamaForm;
  private final YouSettingsForm youForm;

  public ServiceSelectionForm(Disposable parentDisposable) {
    // Each form starts out populated from its service's persisted state.
    openAIForm = new OpenAISettingsForm(OpenAISettings.getCurrentState());
    customOpenAIForm = new CustomServiceForm(
        CustomServiceSettings.getCurrentState());
    anthropicForm = new AnthropicSettingsForm(AnthropicSettings.getCurrentState());
    azureForm = new AzureSettingsForm(AzureSettings.getCurrentState());
    youForm = new YouSettingsForm(YouSettings.getCurrentState(), parentDisposable);
    llamaForm = new LlamaSettingsForm(LlamaSettings.getCurrentState());
  }

  public OpenAISettingsForm getOpenAISettingsForm() {
    return openAIForm;
  }

  public CustomServiceForm getCustomConfigurationSettingsForm() {
    return customOpenAIForm;
  }

  public AnthropicSettingsForm getAnthropicSettingsForm() {
    return anthropicForm;
  }

  public AzureSettingsForm getAzureSettingsForm() {
    return azureForm;
  }

  public YouSettingsForm getYouSettingsForm() {
    return youForm;
  }

  public LlamaSettingsForm getLlamaSettingsForm() {
    return llamaForm;
  }

  /** Reverts every provider form back to its currently persisted settings. */
  public void resetForms() {
    openAIForm.resetForm();
    customOpenAIForm.resetForm();
    anthropicForm.resetForm();
    azureForm.resetForm();
    youForm.resetForm();
    llamaForm.resetForm();
  }
}

View file

@ -1,175 +0,0 @@
package ee.carlrobert.codegpt.settings.service.custom;
import static ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.CUSTOM_SERVICE_API_KEY;
import static ee.carlrobert.codegpt.ui.UIUtil.withEmptyLeftBorder;
import com.intellij.icons.AllIcons.General;
import com.intellij.ide.HelpTooltip;
import com.intellij.openapi.ui.ComboBox;
import com.intellij.openapi.ui.MessageType;
import com.intellij.ui.EnumComboBoxModel;
import com.intellij.ui.TitledSeparator;
import com.intellij.ui.components.JBLabel;
import com.intellij.ui.components.JBPasswordField;
import com.intellij.ui.components.JBTextField;
import com.intellij.util.ui.FormBuilder;
import ee.carlrobert.codegpt.CodeGPTBundle;
import ee.carlrobert.codegpt.completions.CallParameters;
import ee.carlrobert.codegpt.completions.CompletionRequestProvider;
import ee.carlrobert.codegpt.completions.CompletionRequestService;
import ee.carlrobert.codegpt.conversations.Conversation;
import ee.carlrobert.codegpt.conversations.message.Message;
import ee.carlrobert.codegpt.credentials.CredentialsStore;
import ee.carlrobert.codegpt.ui.OverlayUtil;
import ee.carlrobert.codegpt.ui.UIUtil;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import ee.carlrobert.llm.completion.CompletionEventListener;
import java.awt.BorderLayout;
import java.awt.FlowLayout;
import java.net.MalformedURLException;
import java.net.URL;
import javax.swing.Box;
import javax.swing.JButton;
import javax.swing.JPanel;
import javax.swing.SwingUtilities;
import okhttp3.sse.EventSource;
import org.jetbrains.annotations.Nullable;
/**
 * Settings form for a user-defined OpenAI-compatible service. Lets the user
 * pick a preset provider template, enter an API key and endpoint URL, edit
 * request headers/body, and fire a test request against the configured
 * endpoint.
 */
public class CustomServiceForm {

  private final JBPasswordField apiKeyField;
  private final JBTextField urlField;
  private final CustomServiceFormTabbedPane tabbedPane;
  private final JButton testConnectionButton;
  private final JBLabel templateHelpText;
  private final ComboBox<CustomServiceTemplate> templateComboBox;

  public CustomServiceForm(CustomServiceSettingsState settings) {
    apiKeyField = new JBPasswordField();
    apiKeyField.setColumns(30);
    // The API key lives in the credential store, not in the persisted state.
    apiKeyField.setText(CredentialsStore.INSTANCE.getCredential(CUSTOM_SERVICE_API_KEY));
    urlField = new JBTextField(settings.getUrl(), 30);
    tabbedPane = new CustomServiceFormTabbedPane(settings);
    testConnectionButton = new JButton(CodeGPTBundle.get(
        "settingsConfigurable.service.custom.openai.testConnection.label"));
    testConnectionButton.addActionListener(e -> testConnection(getCurrentState()));
    templateHelpText = new JBLabel(General.ContextHelp);
    templateComboBox = new ComboBox<>(
        new EnumComboBoxModel<>(CustomServiceTemplate.class));
    templateComboBox.setSelectedItem(settings.getTemplate());
    // Choosing a template pre-fills the URL, headers and body with its defaults.
    templateComboBox.addItemListener(e -> {
      var template = (CustomServiceTemplate) e.getItem();
      updateTemplateHelpTextTooltip(template);
      urlField.setText(template.getUrl());
      tabbedPane.setHeaders(template.getHeaders());
      tabbedPane.setBody(template.getBody());
    });
    updateTemplateHelpTextTooltip(settings.getTemplate());
  }

  /**
   * Builds the Swing panel containing the template selector, API key field,
   * URL row (with the test-connection button) and the headers/body tabs.
   */
  public JPanel getForm() {
    var urlPanel = new JPanel(new BorderLayout(8, 0));
    urlPanel.add(urlField, BorderLayout.CENTER);
    urlPanel.add(testConnectionButton, BorderLayout.EAST);
    var templateComboBoxWrapper = new JPanel(new FlowLayout(FlowLayout.LEADING, 0, 0));
    templateComboBoxWrapper.add(templateComboBox);
    templateComboBoxWrapper.add(Box.createHorizontalStrut(8));
    templateComboBoxWrapper.add(templateHelpText);
    var form = FormBuilder.createFormBuilder()
        .addLabeledComponent(
            CodeGPTBundle.get("settingsConfigurable.service.custom.openai.presetTemplate.label"),
            templateComboBoxWrapper)
        .addLabeledComponent(
            CodeGPTBundle.get("settingsConfigurable.shared.apiKey.label"),
            apiKeyField)
        .addComponentToRightColumn(
            UIUtil.createComment("settingsConfigurable.service.custom.openai.apiKey.comment"))
        .addLabeledComponent(
            CodeGPTBundle.get("settingsConfigurable.service.custom.openai.url.label"),
            urlPanel)
        .addComponent(tabbedPane)
        .getPanel();
    return FormBuilder.createFormBuilder()
        .addComponent(new TitledSeparator(CodeGPTBundle.get("shared.configuration")))
        .addComponent(withEmptyLeftBorder(form))
        .addComponentFillVertically(new JPanel(), 0)
        .getPanel();
  }

  /**
   * Returns the entered API key, or {@code null} when the field is empty.
   */
  public @Nullable String getApiKey() {
    var apiKey = new String(apiKeyField.getPassword());
    return apiKey.isEmpty() ? null : apiKey;
  }

  /**
   * Snapshots the current UI values into a new state instance. The API key is
   * intentionally excluded: it is stored via the credential store instead.
   */
  public CustomServiceSettingsState getCurrentState() {
    var state = new CustomServiceSettingsState();
    state.setUrl(urlField.getText());
    state.setTemplate(templateComboBox.getItem());
    state.setHeaders(tabbedPane.getHeaders());
    state.setBody(tabbedPane.getBody());
    return state;
  }

  /** Resets all fields to the persisted settings and the stored credential. */
  public void resetForm() {
    var state = CustomServiceSettings.getCurrentState();
    apiKeyField.setText(CredentialsStore.INSTANCE.getCredential(CUSTOM_SERVICE_API_KEY));
    urlField.setText(state.getUrl());
    templateComboBox.setSelectedItem(state.getTemplate());
    tabbedPane.setHeaders(state.getHeaders());
    tabbedPane.setBody(state.getBody());
  }

  // Installs a help tooltip that links to the selected provider's API docs.
  private void updateTemplateHelpTextTooltip(CustomServiceTemplate template) {
    templateHelpText.setToolTipText(null);
    try {
      new HelpTooltip()
          .setTitle(template.getName())
          .setBrowserLink(
              CodeGPTBundle.get("settingsConfigurable.service.custom.openai.linkToDocs"),
              new URL(template.getDocsUrl()))
          .installOn(templateHelpText);
    } catch (MalformedURLException e) {
      // Docs URLs are hard-coded template constants; a bad one is a programming error.
      throw new RuntimeException(e);
    }
  }

  // Sends a single "Hello!" chat completion request to verify connectivity.
  private void testConnection(CustomServiceSettingsState customConfiguration) {
    var conversation = new Conversation();
    var request = new CompletionRequestProvider(conversation)
        .buildCustomOpenAIChatCompletionRequest(
            customConfiguration,
            new CallParameters(conversation, new Message("Hello!")));
    CompletionRequestService.getInstance()
        .getCustomOpenAIChatCompletionAsync(request, new TestConnectionEventListener());
  }

  // Shows a success balloon on the first streamed token, or an error balloon on failure.
  class TestConnectionEventListener implements CompletionEventListener<String> {

    @Override
    public void onMessage(String value, EventSource eventSource) {
      if (value != null && !value.isEmpty()) {
        SwingUtilities.invokeLater(() -> {
          OverlayUtil.showBalloon(
              CodeGPTBundle.get("settingsConfigurable.service.custom.openai.connectionSuccess"),
              MessageType.INFO,
              testConnectionButton);
          // One token is enough to prove connectivity; stop the stream.
          eventSource.cancel();
        });
      }
    }

    @Override
    public void onError(ErrorDetails error, Throwable ex) {
      SwingUtilities.invokeLater(() ->
          OverlayUtil.showBalloon(
              CodeGPTBundle.get("settingsConfigurable.service.custom.openai.connectionFailed")
                  + "\n\n"
                  + error.getMessage(),
              MessageType.ERROR,
              testConnectionButton));
    }
  }
}

View file

@ -18,12 +18,13 @@ class CustomServiceFormTabbedPane extends JBTabbedPane {
private final JBTable headersTable;
private final JBTable bodyTable;
CustomServiceFormTabbedPane(CustomServiceSettingsState customConfiguration) {
CustomServiceFormTabbedPane(Map<String, String> headers, Map<String, ?> body) {
headersTable = new JBTable(
new DefaultTableModel(toArray(customConfiguration.getHeaders()),
new DefaultTableModel(toArray(headers),
new Object[]{"Key", "Value"}));
bodyTable = new JBTable(
new DefaultTableModel(toArray(customConfiguration.getBody()),
new DefaultTableModel(toArray(body),
new Object[]{"Key", "Value"}));
setTabComponentInsets(JBUI.insetsTop(8));
@ -46,11 +47,11 @@ class CustomServiceFormTabbedPane extends JBTabbedPane {
.collect(toMap(Entry::getKey, entry -> (String) entry.getValue()));
}
public void setBody(Map<String, ?> body) {
public void setBody(Map<String, Object> body) {
setTableData(bodyTable, body);
}
public Map<String, ?> getBody() {
public Map<String, Object> getBody() {
return getTableData(bodyTable);
}

View file

@ -1,45 +0,0 @@
package ee.carlrobert.codegpt.settings.service.custom;
import static ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.CUSTOM_SERVICE_API_KEY;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.components.PersistentStateComponent;
import com.intellij.openapi.components.State;
import com.intellij.openapi.components.Storage;
import ee.carlrobert.codegpt.credentials.CredentialsStore;
import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.NotNull;
/**
 * Application-level persistent settings for the custom OpenAI-compatible
 * service, stored in {@code CodeGPT_CustomServiceSettings.xml}.
 */
@State(
    name = "CodeGPT_CustomServiceSettings",
    storages = @Storage("CodeGPT_CustomServiceSettings.xml"))
public class CustomServiceSettings implements PersistentStateComponent<CustomServiceSettingsState> {

  private CustomServiceSettingsState settingsState = new CustomServiceSettingsState();

  public static CustomServiceSettings getInstance() {
    return ApplicationManager.getApplication().getService(CustomServiceSettings.class);
  }

  public static CustomServiceSettingsState getCurrentState() {
    return getInstance().getState();
  }

  @Override
  @NotNull
  public CustomServiceSettingsState getState() {
    return settingsState;
  }

  @Override
  public void loadState(@NotNull CustomServiceSettingsState state) {
    this.settingsState = state;
  }

  /**
   * Reports whether the form differs from the persisted state, including the
   * API key held in the credential store.
   */
  public boolean isModified(CustomServiceForm form) {
    var stateChanged = !form.getCurrentState().equals(settingsState);
    var apiKeyChanged = !StringUtils.equals(
        form.getApiKey(),
        CredentialsStore.INSTANCE.getCredential(CUSTOM_SERVICE_API_KEY));
    return stateChanged || apiKeyChanged;
  }
}

View file

@ -1,69 +0,0 @@
package ee.carlrobert.codegpt.settings.service.custom;
import static ee.carlrobert.codegpt.settings.service.custom.CustomServiceTemplate.OPENAI;
import com.intellij.util.xmlb.annotations.OptionTag;
import ee.carlrobert.codegpt.util.MapConverter;
import java.util.Map;
import java.util.Objects;
/**
 * Serializable state for the custom OpenAI-compatible service configuration.
 * All defaults mirror the {@code OPENAI} preset template.
 */
public class CustomServiceSettingsState {

  private String url = OPENAI.getUrl();
  private Map<String, String> headers = OPENAI.getHeaders();
  // Persisted via a converter because body values are not limited to strings.
  @OptionTag(converter = MapConverter.class)
  private Map<String, ?> body = OPENAI.getBody();
  private CustomServiceTemplate template = OPENAI;

  public String getUrl() {
    return url;
  }

  public void setUrl(String url) {
    this.url = url;
  }

  public Map<String, String> getHeaders() {
    return headers;
  }

  public void setHeaders(Map<String, String> headers) {
    this.headers = headers;
  }

  public Map<String, ?> getBody() {
    return body;
  }

  public void setBody(Map<String, ?> body) {
    this.body = body;
  }

  public CustomServiceTemplate getTemplate() {
    return template;
  }

  public void setTemplate(CustomServiceTemplate template) {
    this.template = template;
  }

  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    // Pattern matching keeps this consistent with the other settings state classes.
    if (!(o instanceof CustomServiceSettingsState that)) {
      return false;
    }
    return Objects.equals(url, that.url)
        && Objects.equals(headers, that.headers)
        && Objects.equals(body, that.body)
        && template == that.template;
  }

  @Override
  public int hashCode() {
    return Objects.hash(url, headers, body, template);
  }
}

View file

@ -1,158 +0,0 @@
package ee.carlrobert.codegpt.settings.service.custom;
import java.util.HashMap;
import java.util.Map;
/**
 * Preset request templates for well-known OpenAI-compatible providers.
 * Each template bundles the provider's display name, documentation link,
 * chat-completions endpoint URL, default request headers, and default
 * request body parameters. The {@code $CUSTOM_SERVICE_API_KEY} and
 * {@code $OPENAI_MESSAGES} placeholders are substituted at request time.
 */
public enum CustomServiceTemplate {

  // Cloud providers
  ANYSCALE(
      "Anyscale",
      "https://docs.endpoints.anyscale.com/",
      "https://api.endpoints.anyscale.com/v1/chat/completions",
      getDefaultHeadersWithAuthentication(),
      getDefaultBodyParams(Map.of(
          "model", "mistralai/Mixtral-8x7B-Instruct-v0.1",
          "max_tokens", 1024))),
  AZURE(
      "Azure OpenAI",
      "https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions",
      // Resource name and deployment id must be filled in by the user.
      "https://{your-resource-name}.openai.azure.com/openai/deployments/{deployment-id}/chat/completions?api-version=2023-05-15",
      getDefaultHeaders("api-key", "$CUSTOM_SERVICE_API_KEY"),
      getDefaultBodyParams(Map.of())),
  DEEP_INFRA(
      "DeepInfra",
      "https://deepinfra.com/docs/advanced/openai_api",
      "https://api.deepinfra.com/v1/openai/chat/completions",
      getDefaultHeadersWithAuthentication(),
      getDefaultBodyParams(Map.of(
          "model", "meta-llama/Llama-2-70b-chat-hf",
          "max_tokens", 1024))),
  FIREWORKS(
      "Fireworks",
      "https://readme.fireworks.ai/reference/createchatcompletion",
      "https://api.fireworks.ai/inference/v1/chat/completions",
      getDefaultHeadersWithAuthentication(),
      getDefaultBodyParams(Map.of(
          "model", "accounts/fireworks/models/llama-v2-7b-chat",
          "max_tokens", 1024))),
  GROQ(
      "Groq",
      "https://docs.api.groq.com/md/openai.oas.html",
      "https://api.groq.com/openai/v1/chat/completions",
      getDefaultHeadersWithAuthentication(),
      getDefaultBodyParams(Map.of(
          "model", "codellama-34b",
          "max_tokens", 1024))),
  OPENAI(
      "OpenAI",
      "https://platform.openai.com/docs/api-reference/chat",
      "https://api.openai.com/v1/chat/completions",
      getDefaultHeaders("Authorization", "Bearer $CUSTOM_SERVICE_API_KEY"),
      getDefaultBodyParams(Map.of(
          "model", "gpt-4",
          "max_tokens", 1024))),
  PERPLEXITY(
      "Perplexity AI",
      "https://docs.perplexity.ai/reference/post_chat_completions",
      "https://api.perplexity.ai/chat/completions",
      getDefaultHeadersWithAuthentication(),
      getDefaultBodyParams(Map.of(
          "model", "codellama",
          "max_tokens", 1024))),
  TOGETHER(
      "Together AI",
      "https://docs.together.ai/docs/openai-api-compatibility",
      "https://api.together.xyz/v1/chat/completions",
      getDefaultHeaders("Authorization", "Bearer $CUSTOM_SERVICE_API_KEY"),
      getDefaultBodyParams(Map.of(
          "model", "deepseek-ai/deepseek-coder-33b-instruct",
          "max_tokens", 1024))),
  // Local providers
  OLLAMA(
      "Ollama",
      "https://github.com/ollama/ollama/blob/main/docs/openai.md",
      "http://localhost:11434/v1/chat/completions",
      getDefaultHeaders(),
      getDefaultBodyParams(Map.of("model", "codellama"))),
  LLAMA_CPP(
      "LLaMA C/C++",
      "https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md",
      "http://localhost:8080/v1/chat/completions",
      getDefaultHeaders(),
      getDefaultBodyParams(Map.of()));

  // Human-readable provider name, also used as the combo-box display text.
  private final String name;
  private final String docsUrl;
  private final String url;
  private final Map<String, String> headers;
  private final Map<String, ?> body;

  CustomServiceTemplate(
      String name,
      String docsUrl,
      String url,
      Map<String, String> headers,
      Map<String, ?> body) {
    this.name = name;
    this.docsUrl = docsUrl;
    this.url = url;
    this.headers = headers;
    this.body = body;
  }

  public String getName() {
    return name;
  }

  public String getDocsUrl() {
    return docsUrl;
  }

  public String getUrl() {
    return url;
  }

  public Map<String, String> getHeaders() {
    return headers;
  }

  public Map<String, ?> getBody() {
    return body;
  }

  @Override
  public String toString() {
    // Show the provider name instead of the enum constant in UI components.
    return name;
  }

  private static Map<String, String> getDefaultHeadersWithAuthentication() {
    return getDefaultHeaders("Authorization", "Bearer $CUSTOM_SERVICE_API_KEY");
  }

  private static Map<String, String> getDefaultHeaders() {
    return getDefaultHeaders(Map.of());
  }

  private static Map<String, String> getDefaultHeaders(String key, String value) {
    return getDefaultHeaders(Map.of(key, value));
  }

  // Merges provider-specific headers on top of the headers shared by all templates.
  private static Map<String, String> getDefaultHeaders(Map<String, String> additionalHeaders) {
    var defaultHeaders = new HashMap<>(Map.of(
        "Content-Type", "application/json",
        "X-LLM-Application-Tag", "codegpt"));
    defaultHeaders.putAll(additionalHeaders);
    return defaultHeaders;
  }

  // Merges provider-specific body params on top of the shared streaming defaults.
  private static Map<String, ?> getDefaultBodyParams(Map<String, ?> additionalParams) {
    var defaultParams = new HashMap<String, Object>(Map.of(
        "stream", true,
        "messages", "$OPENAI_MESSAGES",
        "temperature", 0.1));
    defaultParams.putAll(additionalParams);
    return defaultParams;
  }
}

View file

@ -23,6 +23,7 @@ public class LlamaSettingsState {
private int contextSize = 2048;
private int threads = 8;
private String additionalParameters = "";
private String additionalBuildParameters = "";
private int topK = 40;
private double topP = 0.9;
private double minP = 0.05;
@ -138,6 +139,14 @@ public class LlamaSettingsState {
this.additionalParameters = additionalParameters;
}
public String getAdditionalBuildParameters() {
return additionalBuildParameters;
}
public void setAdditionalBuildParameters(String additionalBuildParameters) {
this.additionalBuildParameters = additionalBuildParameters;
}
public int getTopK() {
return topK;
}
@ -220,6 +229,7 @@ public class LlamaSettingsState {
&& Objects.equals(baseHost, that.baseHost)
&& Objects.equals(serverPort, that.serverPort)
&& Objects.equals(additionalParameters, that.additionalParameters)
&& Objects.equals(additionalBuildParameters, that.additionalBuildParameters)
&& codeCompletionsEnabled == that.codeCompletionsEnabled
&& codeCompletionMaxTokens == that.codeCompletionMaxTokens;
}
@ -229,7 +239,7 @@ public class LlamaSettingsState {
return Objects.hash(runLocalServer, useCustomModel, customLlamaModelPath, huggingFaceModel,
localModelPromptTemplate, remoteModelPromptTemplate, localModelInfillPromptTemplate,
remoteModelInfillPromptTemplate, baseHost, serverPort, contextSize, threads,
additionalParameters, topK, topP, minP, repeatPenalty, codeCompletionsEnabled,
codeCompletionMaxTokens);
additionalParameters, additionalBuildParameters, topK, topP, minP, repeatPenalty,
codeCompletionsEnabled, codeCompletionMaxTokens);
}
}

View file

@ -57,6 +57,7 @@ public class LlamaServerPreferencesForm {
private final IntegerField maxTokensField;
private final IntegerField threadsField;
private final JBTextField additionalParametersField;
private final JBTextField additionalBuildParametersField;
private final ChatPromptTemplatePanel remotePromptTemplatePanel;
private final InfillPromptTemplatePanel infillPromptTemplatePanel;
@ -79,6 +80,9 @@ public class LlamaServerPreferencesForm {
additionalParametersField = new JBTextField(settings.getAdditionalParameters(), 30);
additionalParametersField.setEnabled(!serverRunning);
additionalBuildParametersField = new JBTextField(settings.getAdditionalBuildParameters(), 30);
additionalBuildParametersField.setEnabled(!serverRunning);
baseHostField = new JBTextField(settings.getBaseHost(), 30);
apiKeyField = new JBPasswordField();
apiKeyField.setColumns(30);
@ -124,6 +128,7 @@ public class LlamaServerPreferencesForm {
maxTokensField.setValue(state.getContextSize());
threadsField.setValue(state.getThreads());
additionalParametersField.setText(state.getAdditionalParameters());
additionalBuildParametersField.setText(state.getAdditionalBuildParameters());
remotePromptTemplatePanel.setPromptTemplate(state.getRemoteModelPromptTemplate()); // ?
infillPromptTemplatePanel.setPromptTemplate(state.getRemoteModelInfillPromptTemplate());
apiKeyField.setText(CredentialsStore.INSTANCE.getCredential(LLAMA_API_KEY));
@ -184,9 +189,17 @@ public class LlamaServerPreferencesForm {
createComment("settingsConfigurable.service.llama.threads.comment"))
.addLabeledComponent(
CodeGPTBundle.get("settingsConfigurable.service.llama.additionalParameters.label"),
additionalParametersField)
.addComponentToRightColumn(
createComment("settingsConfigurable.service.llama.additionalParameters.comment"))
additionalParametersField)
.addComponentToRightColumn(
createComment(
"settingsConfigurable.service.llama.additionalParameters.comment"))
.addLabeledComponent(
CodeGPTBundle.get(
"settingsConfigurable.service.llama.additionalBuildParameters.label"),
additionalBuildParametersField)
.addComponentToRightColumn(
createComment(
"settingsConfigurable.service.llama.additionalBuildParameters.comment"))
.addVerticalGap(4)
.addComponentFillVertically(new JPanel(), 0)
.getPanel()))
@ -196,6 +209,7 @@ public class LlamaServerPreferencesForm {
private JButton getServerButton(
LlamaServerAgent llamaServerAgent,
ServerProgressPanel serverProgressPanel) {
llamaServerAgent.setActiveServerProgressPanel(serverProgressPanel);
var serverRunning = llamaServerAgent.isServerRunning();
var serverButton = new JButton();
serverButton.setText(serverRunning
@ -218,7 +232,9 @@ public class LlamaServerPreferencesForm {
getContextSize(),
getThreads(),
getServerPort(),
getListOfAdditionalParameters()),
getListOfAdditionalParameters(),
getListOfAdditionalBuildParameters()
),
serverProgressPanel,
() -> {
setFormEnabled(false);
@ -227,12 +243,12 @@ public class LlamaServerPreferencesForm {
Actions.Checked,
SwingConstants.LEADING));
},
() -> {
(activeServerProgressPanel) -> {
setFormEnabled(true);
serverButton.setText(
CodeGPTBundle.get("settingsConfigurable.service.llama.startServer.label"));
serverButton.setIcon(Actions.Execute);
serverProgressPanel.displayComponent(new JBLabel(
activeServerProgressPanel.displayComponent(new JBLabel(
CodeGPTBundle.get("settingsConfigurable.service.llama.progress.serverTerminated"),
Actions.Cancel,
SwingConstants.LEADING));
@ -282,7 +298,7 @@ public class LlamaServerPreferencesForm {
serverButton.setText(
CodeGPTBundle.get("settingsConfigurable.service.llama.startServer.label"));
serverButton.setIcon(Actions.Execute);
progressPanel.updateText(
progressPanel.displayText(
CodeGPTBundle.get("settingsConfigurable.service.llama.progress.stoppingServer"));
}
@ -291,7 +307,7 @@ public class LlamaServerPreferencesForm {
serverButton.setText(
CodeGPTBundle.get("settingsConfigurable.service.llama.stopServer.label"));
serverButton.setIcon(Actions.Suspend);
progressPanel.startProgress(
progressPanel.displayText(
CodeGPTBundle.get("settingsConfigurable.service.llama.progress.startingServer"));
}
@ -301,6 +317,7 @@ public class LlamaServerPreferencesForm {
maxTokensField.setEnabled(enabled);
threadsField.setEnabled(enabled);
additionalParametersField.setEnabled(enabled);
additionalBuildParametersField.setEnabled(enabled);
}
public boolean isRunLocalServer() {
@ -337,9 +354,20 @@ public class LlamaServerPreferencesForm {
public List<String> getListOfAdditionalParameters() {
return Arrays.stream(additionalParametersField.getText().split(","))
.map(String::trim)
.filter(s -> !s.isBlank())
.toList();
.map(String::trim)
.filter(s -> !s.isBlank())
.toList();
}
public String getAdditionalBuildParameters() {
return additionalBuildParametersField.getText();
}
public List<String> getListOfAdditionalBuildParameters() {
return Arrays.stream(additionalBuildParametersField.getText().split(","))
.map(String::trim)
.filter(s -> !s.isBlank())
.toList();
}
public PromptTemplate getPromptTemplate() {

View file

@ -41,6 +41,7 @@ public class LlamaSettingsForm extends JPanel {
state.setContextSize(llamaServerPreferencesForm.getContextSize());
state.setThreads(llamaServerPreferencesForm.getThreads());
state.setAdditionalParameters(llamaServerPreferencesForm.getAdditionalParameters());
state.setAdditionalBuildParameters(llamaServerPreferencesForm.getAdditionalBuildParameters());
var modelPreferencesForm = llamaServerPreferencesForm.getLlamaModelPreferencesForm();
state.setCustomLlamaModelPath(modelPreferencesForm.getCustomLlamaModelPath());

View file

@ -8,20 +8,15 @@ import javax.swing.JPanel;
public class ServerProgressPanel extends JPanel {
private final JBLabel label = new JBLabel();
private final AsyncProcessIcon loadingSpinner = new AsyncProcessIcon("sign_in_spinner");
public ServerProgressPanel() {
setVisible(false);
add(new AsyncProcessIcon("sign_in_spinner"));
add(label);
}
public void startProgress(String text) {
setVisible(true);
updateText(text);
}
public void updateText(String text) {
public void displayText(String text) {
label.setText(text);
removeAll();
add(loadingSpinner);
add(label);
revalidate();
repaint();
}
public void displayComponent(JComponent component) {

View file

@ -82,7 +82,10 @@ public class ModelComboBoxAction extends ComboBoxAction {
actionGroup.addSeparator("Custom OpenAI Service");
actionGroup.add(createModelAction(
CUSTOM_OPENAI,
CustomServiceSettings.getCurrentState().getTemplate().getName(),
ApplicationManager.getApplication().getService(CustomServiceSettings.class)
.getState()
.getTemplate()
.getProviderName(),
Icons.OpenAI,
presentation));
actionGroup.addSeparator();
@ -150,9 +153,11 @@ public class ModelComboBoxAction extends ComboBoxAction {
break;
case CUSTOM_OPENAI:
templatePresentation.setIcon(Icons.OpenAI);
templatePresentation.setText(CustomServiceSettings.getCurrentState()
.getTemplate()
.getName());
templatePresentation.setText(
ApplicationManager.getApplication().getService(CustomServiceSettings.class)
.getState()
.getTemplate()
.getProviderName());
break;
case ANTHROPIC:
templatePresentation.setIcon(Icons.Anthropic);

View file

@ -149,4 +149,13 @@ public class OverlayUtil {
.createBalloon()
.show(RelativePoint.getSouthOf(component), Position.below);
}
public static void showClosableBalloon(String content, MessageType messageType,
JComponent component) {
JBPopupFactory.getInstance()
.createHtmlTextBalloonBuilder(content, messageType, null)
.setCloseButtonEnabled(true)
.createBalloon()
.show(RelativePoint.getSouthOf(component), Position.below);
}
}

View file

@ -1,43 +0,0 @@
package ee.carlrobert.codegpt.util;
import com.intellij.openapi.application.Application;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.project.Project;
import com.intellij.openapi.project.ProjectManager;
import com.intellij.openapi.wm.IdeFocusManager;
import com.intellij.openapi.wm.IdeFrame;
import org.jetbrains.annotations.Nullable;
public class ApplicationUtil {
private ApplicationUtil() {
}
public static boolean isUnitTestingMode() {
Application app = ApplicationManager.getApplication();
return app != null && app.isUnitTestMode();
}
@Nullable
public static Project findCurrentProject() {
IdeFrame frame = IdeFocusManager.getGlobalInstance().getLastFocusedFrame();
Project project = frame != null ? frame.getProject() : null;
if (isValidProject(project)) {
return project;
}
return findProjectFromOpenProjects();
}
private static Project findProjectFromOpenProjects() {
for (Project project : ProjectManager.getInstance().getOpenProjects()) {
if (isValidProject(project)) {
return project;
}
}
return null;
}
private static boolean isValidProject(@Nullable Project project) {
return project != null && !project.isDisposed() && !project.isDefault();
}
}

View file

@ -1,38 +0,0 @@
package ee.carlrobert.codegpt.util;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.datatype.jdk8.Jdk8Module;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import com.intellij.util.xmlb.Converter;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
/**
 * Base {@link Converter} implementation that (de)serializes persisted IDE settings values
 * to and from JSON using Jackson.
 *
 * <p>Subclasses supply the concrete target type through a {@link TypeReference} so that
 * generic types (e.g. {@code Map<String, Object>}) survive erasure.
 */
public abstract class BaseConverter<T> extends Converter<T> {

  private final TypeReference<T> typeReference;
  // Jdk8Module and JavaTimeModule are registered so Optional and java.time values round-trip.
  private final ObjectMapper objectMapper = new ObjectMapper()
      .registerModule(new Jdk8Module())
      .registerModule(new JavaTimeModule());

  public BaseConverter(TypeReference<T> typeReference) {
    this.typeReference = typeReference;
  }

  /**
   * Deserializes the persisted JSON string into the target type.
   *
   * @throws RuntimeException when the value cannot be parsed as the expected type
   */
  public @Nullable T fromString(@NotNull String value) {
    try {
      return objectMapper.readValue(value, typeReference);
    } catch (JsonProcessingException e) {
      // The message names the actual target type: this converter is generic and is not
      // limited to conversation data.
      throw new RuntimeException(
          "Unable to deserialize value of type " + typeReference.getType(), e);
    }
  }

  /**
   * Serializes the given value to its persisted JSON representation.
   *
   * @throws RuntimeException when the value cannot be written as JSON
   */
  public @Nullable String toString(@NotNull T value) {
    try {
      return objectMapper.writeValueAsString(value);
    } catch (JsonProcessingException e) {
      throw new RuntimeException(
          "Unable to serialize value of type " + typeReference.getType(), e);
    }
  }
}

View file

@ -1,152 +0,0 @@
package ee.carlrobert.codegpt.util;
import static java.lang.String.format;
import com.intellij.codeInsight.daemon.DaemonCodeAnalyzer;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.application.PathManager;
import com.intellij.openapi.command.WriteCommandAction;
import com.intellij.openapi.editor.Document;
import com.intellij.openapi.editor.Editor;
import com.intellij.openapi.editor.EditorFactory;
import com.intellij.openapi.editor.EditorKind;
import com.intellij.openapi.fileEditor.FileDocumentManager;
import com.intellij.openapi.fileEditor.FileEditor;
import com.intellij.openapi.fileEditor.FileEditorManager;
import com.intellij.openapi.fileEditor.TextEditor;
import com.intellij.openapi.fileEditor.impl.FileEditorManagerImpl;
import com.intellij.openapi.project.Project;
import com.intellij.psi.PsiDocumentManager;
import com.intellij.psi.codeStyle.CodeStyleManager;
import com.intellij.testFramework.LightVirtualFile;
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
/**
 * Static helpers for creating, querying, and mutating IntelliJ editors and their documents.
 *
 * <p>All document mutations funnel through write actions/write commands as required by the
 * IntelliJ threading model.
 */
public final class EditorUtil {

  private EditorUtil() {
  }

  /**
   * Creates a main-kind editor over an in-memory {@link LightVirtualFile} named
   * {@code temp_<timestamp><fileExtension>} under the IDE temp path, containing {@code code}.
   * Daemon highlighting is disabled for the backing document before the editor is created.
   *
   * @param project       project the editor belongs to
   * @param fileExtension extension (including the leading dot, per call convention) used so the
   *                      IDE can pick syntax support for the temp file
   * @param code          initial document content
   */
  public static Editor createEditor(@NotNull Project project, String fileExtension, String code) {
    var timestamp = DateTimeFormatter.ofPattern("yyyyMMddHHmmss").format(LocalDateTime.now());
    var fileName = "temp_" + timestamp + fileExtension;
    var lightVirtualFile = new LightVirtualFile(
        format("%s/%s", PathManager.getTempPath(), fileName),
        code);
    // Reuse an already-registered document for this virtual file if one exists; otherwise
    // create a fresh document seeded with the code.
    var existingDocument = FileDocumentManager.getInstance().getDocument(lightVirtualFile);
    var document = existingDocument != null
        ? existingDocument
        : EditorFactory.getInstance().createDocument(code);
    disableHighlighting(project, document);
    // NOTE(review): the boolean flag presumably marks the editor as a viewer (read-only) —
    // confirm against the EditorFactory#createEditor javadoc.
    return EditorFactory.getInstance().createEditor(
        document,
        project,
        lightVirtualFile,
        true,
        EditorKind.MAIN_EDITOR);
  }

  /**
   * Replaces the editor's entire document text with {@code content} inside a write command,
   * then repaints/revalidates the editor component.
   *
   * <p>In unit-test mode the update runs synchronously via {@code invokeAndWait} so tests see
   * the result immediately; otherwise it is scheduled with {@code invokeLater}.
   */
  public static void updateEditorDocument(Editor editor, String content) {
    var document = editor.getDocument();
    var application = ApplicationManager.getApplication();
    Runnable updateDocumentRunnable = () -> application.runWriteAction(() ->
        WriteCommandAction.runWriteCommandAction(editor.getProject(), () -> {
          document.replaceString(0, document.getTextLength(), content);
          editor.getComponent().repaint();
          editor.getComponent().revalidate();
        }));

    if (application.isUnitTestMode()) {
      application.invokeAndWait(updateDocumentRunnable);
    } else {
      application.invokeLater(updateDocumentRunnable);
    }
  }

  /** Returns {@code true} when the editor exists and currently has a text selection. */
  public static boolean hasSelection(@Nullable Editor editor) {
    return editor != null && editor.getSelectionModel().hasSelection();
  }

  /** Returns the currently selected text editor of the project, or {@code null} if none. */
  public static @Nullable Editor getSelectedEditor(@NotNull Project project) {
    FileEditorManager editorManager = FileEditorManager.getInstance(project);
    return editorManager != null ? editorManager.getSelectedTextEditor() : null;
  }

  /**
   * Returns the selected text of the project's currently selected editor, or {@code null}
   * when there is no selected editor (or no selection).
   */
  public static @Nullable String getSelectedEditorSelectedText(@NotNull Project project) {
    var selectedEditor = EditorUtil.getSelectedEditor(project);
    if (selectedEditor != null) {
      return selectedEditor.getSelectionModel().getSelectedText();
    }
    return null;
  }

  /**
   * Tells whether {@code editor} is the editor currently selected in its project.
   *
   * <p>Uses the {@link FileEditorManagerImpl} fast path when available; otherwise falls back
   * to comparing against the selected {@link TextEditor}.
   */
  public static boolean isSelectedEditor(Editor editor) {
    Project project = editor.getProject();
    if (project != null && !project.isDisposed()) {
      FileEditorManager editorManager = FileEditorManager.getInstance(project);
      if (editorManager == null) {
        return false;
      }
      if (editorManager instanceof FileEditorManagerImpl) {
        Editor current = ((FileEditorManagerImpl) editorManager).getSelectedTextEditor(true);
        return current != null && current.equals(editor);
      }
      FileEditor current = editorManager.getSelectedEditor();
      return current instanceof TextEditor && editor.equals(((TextEditor) current).getEditor());
    }
    return false;
  }

  /** Returns {@code true} when the project's selected editor has a text selection. */
  public static boolean isMainEditorTextSelected(@NotNull Project project) {
    return hasSelection(getSelectedEditor(project));
  }

  /**
   * Replaces the current selection of the project's selected editor with {@code text} inside a
   * write command, optionally reformats the replaced range (per configuration), focuses the
   * editor, and clears the selection. The work is scheduled asynchronously on the EDT.
   */
  public static void replaceMainEditorSelection(@NotNull Project project, @NotNull String text) {
    var application = ApplicationManager.getApplication();
    application.invokeLater(() ->
        application.runWriteAction(() -> WriteCommandAction.runWriteCommandAction(project, () -> {
          var editor = getSelectedEditor(project);
          if (editor != null) {
            var selectionModel = editor.getSelectionModel();
            int startOffset = selectionModel.getSelectionStart();
            int endOffset = selectionModel.getSelectionEnd();

            var document = editor.getDocument();
            document.replaceString(startOffset, endOffset, text);
            if (ConfigurationSettings.getCurrentState().isAutoFormattingEnabled()) {
              reformatDocument(project, document, startOffset, endOffset);
            }

            editor.getContentComponent().requestFocus();
            selectionModel.removeSelection();
          }
        })));
  }

  /**
   * Reformats the given document range with the project's code style.
   *
   * <p>The document is committed first so the PSI tree is in sync before reformatting.
   */
  public static void reformatDocument(
      @NotNull Project project,
      @NotNull Document document,
      int startOffset,
      int endOffset) {
    var psiDocumentManager = PsiDocumentManager.getInstance(project);
    psiDocumentManager.commitDocument(document);

    var psiFile = psiDocumentManager.getPsiFile(document);
    if (psiFile != null) {
      CodeStyleManager.getInstance(project)
          .reformatText(psiFile, startOffset, endOffset);
    }
  }

  /** Disables daemon code-analysis highlighting for the PSI file backing {@code document}. */
  public static void disableHighlighting(@NotNull Project project, Document document) {
    var psiFile = PsiDocumentManager.getInstance(project).getPsiFile(document);
    if (psiFile != null) {
      DaemonCodeAnalyzer.getInstance(project).setHighlightingEnabled(psiFile, false);
    }
  }
}

View file

@ -1,11 +0,0 @@
package ee.carlrobert.codegpt.util;
import com.fasterxml.jackson.core.type.TypeReference;
import java.util.Map;
/**
 * {@link BaseConverter} specialization that persists {@code Map<String, Object>} values as JSON.
 *
 * <p>The anonymous {@code TypeReference} subclass captures the generic map type so Jackson can
 * deserialize it despite erasure.
 */
public class MapConverter extends BaseConverter<Map<String, Object>> {

  public MapConverter() {
    super(new TypeReference<>() {});
  }
}

View file

@ -1,47 +0,0 @@
package ee.carlrobert.codegpt.util;
import com.vladsch.flexmark.html.HtmlRenderer;
import com.vladsch.flexmark.parser.Parser;
import com.vladsch.flexmark.util.data.MutableDataSet;
import ee.carlrobert.codegpt.toolwindow.chat.ResponseNodeRenderer;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/** Helpers for working with markdown-formatted model responses. */
public class MarkdownUtil {

  private MarkdownUtil() {
  }

  /**
   * Splits a markdown string into an ordered list of segments, where each segment is either a
   * fenced code block (surrounded by triple backticks, fences included) or the plain text
   * between code blocks. Blank segments are dropped.
   *
   * @param inputMarkdown the markdown-formatted input to split
   * @return the non-blank code and non-code segments, in document order
   */
  public static List<String> splitCodeBlocks(String inputMarkdown) {
    var fencedBlockMatcher = Pattern.compile("(?s)```.*?```").matcher(inputMarkdown);
    var segments = new ArrayList<String>();
    var cursor = 0;
    while (fencedBlockMatcher.find()) {
      // Text preceding the fence, then the fenced block itself.
      segments.add(inputMarkdown.substring(cursor, fencedBlockMatcher.start()));
      segments.add(fencedBlockMatcher.group());
      cursor = fencedBlockMatcher.end();
    }
    // Trailing text after the last fence (or the whole input when no fence matched).
    segments.add(inputMarkdown.substring(cursor));
    return segments.stream().filter(segment -> !segment.isBlank()).toList();
  }

  /**
   * Renders a markdown message to HTML using flexmark, applying the project's custom
   * {@code ResponseNodeRenderer} for response-specific nodes.
   */
  public static String convertMdToHtml(String message) {
    var options = new MutableDataSet();
    var parsedDocument = Parser.builder(options).build().parse(message);
    var renderer = HtmlRenderer.builder(options)
        .nodeRendererFactory(new ResponseNodeRenderer.Factory())
        .build();
    return renderer.render(parsedDocument);
  }
}

View file

@ -1,23 +0,0 @@
package ee.carlrobert.codegpt.util.file;
/**
 * Jackson data-binding bean for one entry of {@code fileExtensionLanguageMappings.json}:
 * maps a file extension to a language name.
 *
 * <p>Getter/setter names form the JSON contract — do not rename them.
 */
public class FileExtensionLanguageDetails {

  // File extension key, e.g. "py". Presumably without a leading dot — TODO confirm
  // against fileExtensionLanguageMappings.json.
  private String extension;
  // Associated language name, e.g. "Python".
  private String value;

  public String getExtension() {
    return extension;
  }

  public void setExtension(String extension) {
    this.extension = extension;
  }

  public String getValue() {
    return value;
  }

  public void setValue(String value) {
    this.value = value;
  }
}

View file

@ -1,195 +0,0 @@
package ee.carlrobert.codegpt.util.file;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.intellij.openapi.diagnostic.Logger;
import com.intellij.openapi.editor.Editor;
import com.intellij.openapi.fileEditor.FileDocumentManager;
import com.intellij.openapi.progress.ProgressIndicator;
import com.intellij.openapi.vfs.VirtualFile;
import ee.carlrobert.codegpt.CodeGPTPlugin;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.Writer;
import java.net.URL;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.text.DecimalFormat;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.jetbrains.annotations.NotNull;
/**
 * File-system and file-metadata helpers: file creation, progress-reporting downloads,
 * extension/language mapping, and human-readable size formatting.
 */
public class FileUtil {

  private FileUtil() {
  }

  private static final Logger LOG = Logger.getInstance(FileUtil.class);

  // Matches the trailing run of non-dot characters; hoisted so the pattern is compiled once.
  private static final Pattern FILE_EXTENSION_PATTERN = Pattern.compile("[^.]+$");

  /**
   * Writes {@code fileContent} to {@code fileName} inside {@code directoryPath}, creating the
   * directory and file when missing and fully replacing any existing content.
   *
   * @return the written file
   * @throws RuntimeException when the directory or file cannot be created/written
   */
  public static File createFile(String directoryPath, String fileName, String fileContent) {
    try {
      tryCreateDirectory(directoryPath);
      // TRUNCATE_EXISTING is required: with CREATE alone, overwriting an existing file with
      // shorter content would leave the previous file's trailing bytes in place.
      return Files.writeString(
          Path.of(directoryPath, fileName),
          fileContent,
          StandardOpenOption.CREATE,
          StandardOpenOption.TRUNCATE_EXISTING,
          StandardOpenOption.WRITE).toFile();
    } catch (IOException e) {
      throw new RuntimeException("Failed to create file", e);
    }
  }

  /**
   * Downloads {@code url} into the plugin's llama models directory as {@code fileName},
   * reporting progress through {@code indicator} and accumulating the byte count in
   * {@code bytesRead[0]}. Stops early (leaving a partial file) when the indicator is cancelled.
   *
   * @param bytesRead single-element accumulator for the total bytes written so far
   * @param fileSize  expected total size, used to compute the progress fraction
   */
  public static void copyFileWithProgress(
      String fileName,
      URL url,
      long[] bytesRead,
      long fileSize,
      ProgressIndicator indicator) throws IOException {
    FileUtil.tryCreateDirectory(CodeGPTPlugin.getLlamaModelsPath());

    try (
        var readableByteChannel = Channels.newChannel(url.openStream());
        var fileOutputStream = new FileOutputStream(
            CodeGPTPlugin.getLlamaModelsPath() + File.separator + fileName)) {
      var buffer = ByteBuffer.allocateDirect(1024 * 10);
      while (readableByteChannel.read(buffer) != -1) {
        if (indicator.isCanceled()) {
          // Early close is redundant with try-with-resources but harmless; it aborts the
          // in-flight stream immediately on cancellation.
          readableByteChannel.close();
          break;
        }
        buffer.flip();
        bytesRead[0] += fileOutputStream.getChannel().write(buffer);
        buffer.clear();

        indicator.setFraction((double) bytesRead[0] / fileSize);
      }
    }
  }

  /** Returns the virtual file backing the editor's document, or {@code null} if there is none. */
  public static VirtualFile getEditorFile(@NotNull Editor editor) {
    return FileDocumentManager.getInstance().getFile(editor.getDocument());
  }

  /**
   * Creates {@code directoryPath} (including parents) when it does not already exist.
   *
   * @throws RuntimeException when the directory cannot be created
   */
  public static void tryCreateDirectory(String directoryPath) {
    try {
      if (!com.intellij.openapi.util.io.FileUtil.exists(directoryPath)) {
        if (!com.intellij.openapi.util.io.FileUtil.createDirectory(
            Path.of(directoryPath).toFile())) {
          throw new IOException("Failed to create directory: " + directoryPath);
        }
      }
    } catch (IOException e) {
      throw new RuntimeException("Failed to create directory", e);
    }
  }

  /**
   * Returns the trailing extension of {@code filename} (without the dot), or {@code ""} when
   * the name ends with a dot. Note: for a name containing no dot at all, the whole name is
   * returned (the pattern only anchors on the trailing non-dot run).
   */
  public static String getFileExtension(String filename) {
    Matcher matcher = FILE_EXTENSION_PATTERN.matcher(filename);

    if (matcher.find()) {
      return matcher.group();
    }
    return "";
  }

  /**
   * Resolves a (language name, file extension) pair for the given language identifier by
   * consulting the bundled language/extension mapping resources in both directions.
   * Falls back to {@code ("Text", ".txt")} when the mappings cannot be parsed or no match
   * is found.
   */
  public static Map.Entry<String, String> findLanguageExtensionMapping(String language) {
    var defaultValue = Map.entry("Text", ".txt");
    var mapper = new ObjectMapper();

    List<FileExtensionLanguageDetails> extensionToLanguageMappings;
    List<LanguageFileExtensionDetails> languageToExtensionMappings;
    try {
      extensionToLanguageMappings = mapper.readValue(
          getResourceContent("/fileExtensionLanguageMappings.json"), new TypeReference<>() {
          });
      languageToExtensionMappings = mapper.readValue(
          getResourceContent("/languageFileExtensionMappings.json"), new TypeReference<>() {
          });
    } catch (JsonProcessingException e) {
      LOG.error("Unable to extract file extension", e);
      return defaultValue;
    }

    // First try the identifier as a language name; failing that, treat it as a file
    // extension and map extension -> language -> extension.
    return findFirstExtension(languageToExtensionMappings, language)
        .or(() -> extensionToLanguageMappings.stream()
            .filter(it -> it.getExtension().equalsIgnoreCase(language))
            .findFirst()
            .flatMap(it -> findFirstExtension(languageToExtensionMappings, it.getValue()))
        ).orElse(defaultValue);
  }

  /**
   * Best-effort check that the file at {@code filePath} decodes as UTF-8.
   *
   * <p>{@link Files#newBufferedReader} decodes strictly as UTF-8 and throws on malformed
   * input, so any decode (or I/O) failure yields {@code false}.
   */
  public static boolean isUtf8File(String filePath) {
    var path = Paths.get(filePath);
    try (var reader = Files.newBufferedReader(path)) {
      int c = reader.read();
      if (c >= 0) {
        // Drain the remainder so the whole file is decoded, not just the first char.
        reader.transferTo(Writer.nullWriter());
      }
      return true;
    } catch (Exception e) {
      return false;
    }
  }

  /**
   * Returns the media type for a supported image file name.
   *
   * @throws IllegalArgumentException for extensions other than png/jpg/jpeg
   */
  public static String getImageMediaType(String fileName) {
    var fileExtension = getFileExtension(fileName);
    return switch (fileExtension) {
      case "png" -> "image/png";
      case "jpg", "jpeg" -> "image/jpeg";
      default -> throw new IllegalArgumentException("Unsupported image type: " + fileExtension);
    };
  }

  /**
   * Reads a classpath resource as a UTF-8 string.
   *
   * @throws RuntimeException when the resource is missing or unreadable
   */
  public static String getResourceContent(String name) {
    try (var stream = Objects.requireNonNull(FileUtil.class.getResourceAsStream(name))) {
      return new String(stream.readAllBytes(), StandardCharsets.UTF_8);
    } catch (IOException e) {
      throw new RuntimeException("Unable to read resource", e);
    }
  }

  /**
   * Formats a byte count as a human-readable size using 1024-based units,
   * e.g. {@code 1536 -> "1.5 KB"}.
   */
  public static String convertFileSize(long fileSizeInBytes) {
    String[] units = {"B", "KB", "MB", "GB"};
    int unitIndex = 0;
    double fileSize = fileSizeInBytes;

    while (fileSize >= 1024 && unitIndex < units.length - 1) {
      fileSize /= 1024;
      unitIndex++;
    }

    return new DecimalFormat("#.##").format(fileSize) + " " + units[unitIndex];
  }

  /**
   * Formats a count with K/M suffixes using truncating integer division,
   * e.g. {@code 1_500_000 -> "1M"}, {@code 1_500 -> "1K"}.
   */
  public static String convertLongValue(long value) {
    if (value >= 1_000_000) {
      return value / 1_000_000 + "M";
    }
    if (value >= 1_000) {
      return value / 1_000 + "K";
    }
    return String.valueOf(value);
  }

  // Case-insensitive lookup of a language by name, mapped to (name, first extension).
  private static Optional<Map.Entry<String, String>> findFirstExtension(
      List<LanguageFileExtensionDetails> languageFileExtensionMappings,
      String language) {
    return languageFileExtensionMappings.stream()
        .filter(item -> language.equalsIgnoreCase(item.getName()))
        .findFirst()
        .map(it -> Map.entry(it.getName(), it.getExtensions().get(0)));
  }
}

View file

@ -1,34 +0,0 @@
package ee.carlrobert.codegpt.util.file;
import java.util.List;
/**
 * Jackson data-binding bean for one entry of {@code languageFileExtensionMappings.json}:
 * maps a language name to its known file extensions.
 *
 * <p>Getter/setter names form the JSON contract — do not rename them.
 */
public class LanguageFileExtensionDetails {

  // Language display name, e.g. "Python".
  private String name;
  // Category of the entry. Semantics not visible here — presumably a classification such as
  // "programming" vs "data"; confirm against languageFileExtensionMappings.json.
  private String type;
  // Known extensions for the language; callers (FileUtil) treat index 0 as the primary one.
  private List<String> extensions;

  public String getName() {
    return name;
  }

  public void setName(String name) {
    this.name = name;
  }

  public String getType() {
    return type;
  }

  public void setType(String type) {
    this.type = type;
  }

  public List<String> getExtensions() {
    return extensions;
  }

  public void setExtensions(List<String> extensions) {
    this.extensions = extensions;
  }
}

View file

@ -3,9 +3,9 @@ package ee.carlrobert.codegpt
import com.intellij.notification.NotificationAction
import com.intellij.notification.NotificationType
import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.components.service
import com.intellij.openapi.project.Project
import com.intellij.openapi.startup.ProjectActivity
import com.intellij.openapi.util.Disposer
import ee.carlrobert.codegpt.actions.editor.EditorActionsUtil
import ee.carlrobert.codegpt.completions.you.YouUserManager
import ee.carlrobert.codegpt.completions.you.auth.AuthenticationHandler
@ -19,10 +19,14 @@ import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
import ee.carlrobert.codegpt.settings.service.you.YouSettings
import ee.carlrobert.codegpt.toolwindow.chat.ui.textarea.AttachImageNotifier
import ee.carlrobert.codegpt.ui.OverlayUtil
import io.ktor.util.*
import java.nio.file.Paths
import kotlin.io.path.absolutePathString
class CodeGPTProjectActivity : ProjectActivity {
private val watchExtensions = listOf("jpg", "jpeg", "png")
override suspend fun execute(project: Project) {
EditorActionsUtil.refreshActions()
CredentialsStore.loadAll()
@ -34,14 +38,13 @@ class CodeGPTProjectActivity : ProjectActivity {
if (!ApplicationManager.getApplication().isUnitTestMode
&& ConfigurationSettings.getCurrentState().isCheckForNewScreenshots
) {
val pathToWatch = Paths.get(System.getProperty("user.home"), "Desktop")
val fileWatcher = FileWatcher(pathToWatch)
fileWatcher.watch {
if (listOf("jpg", "jpeg", "png").contains(it.extension)) {
showImageAttachmentNotification(project, it.absolutePath)
val desktopPath = Paths.get(System.getProperty("user.home"), "Desktop")
project.service<FileWatcher>()
.watch(desktopPath) {
if (watchExtensions.contains(it.extension.lowercase())) {
showImageAttachmentNotification(project, desktopPath.resolve(it).absolutePathString())
}
}
}
Disposer.register(project, fileWatcher)
}
}
@ -97,4 +100,4 @@ class CodeGPTProjectActivity : ProjectActivity {
})
.notify(project)
}
}
}

View file

@ -1,29 +1,34 @@
package ee.carlrobert.codegpt
import com.intellij.openapi.Disposable
import org.apache.commons.io.monitor.FileAlterationListenerAdaptor
import org.apache.commons.io.monitor.FileAlterationMonitor
import org.apache.commons.io.monitor.FileAlterationObserver
import java.io.File
import com.intellij.openapi.components.Service
import java.nio.file.FileSystems
import java.nio.file.Path
import java.nio.file.StandardWatchEventKinds.ENTRY_CREATE
import java.nio.file.WatchKey
import kotlin.concurrent.thread
class FileWatcher(private val pathToWatch: Path) : Disposable {
private val fileMonitor =
FileAlterationMonitor(500, FileAlterationObserver(pathToWatch.toFile()))
@Service(Service.Level.PROJECT)
class FileWatcher : Disposable {
fun watch(onFileCreated: (File) -> Unit) {
val observer = FileAlterationObserver(pathToWatch.toFile())
observer.addListener(object : FileAlterationListenerAdaptor() {
override fun onFileCreate(file: File) {
onFileCreated(file)
private var fileMonitor: Thread? = null
fun watch(pathToWatch: Path, onFileCreated: (Path) -> Unit) {
val watchService = FileSystems.getDefault().newWatchService()
pathToWatch.register(watchService, ENTRY_CREATE) // watch for new files
fileMonitor = thread {
var key: WatchKey
while ((watchService.take().also { key = it }) != null) {
for (event in key.pollEvents()) {
onFileCreated(event.context() as Path)
}
key.reset()
}
})
fileMonitor.addObserver(observer)
fileMonitor.start()
}
}
override fun dispose() {
fileMonitor.stop()
fileMonitor?.interrupt()
}
}
}

View file

@ -2,11 +2,12 @@ package ee.carlrobert.codegpt.actions
import com.intellij.openapi.actionSystem.ActionUpdateThread
import com.intellij.openapi.actionSystem.AnActionEvent
import com.intellij.openapi.components.service
import com.intellij.openapi.project.DumbAwareAction
import ee.carlrobert.codegpt.settings.GeneralSettings
import ee.carlrobert.codegpt.settings.service.ServiceType
import ee.carlrobert.codegpt.settings.service.ServiceType.LLAMA_CPP
import ee.carlrobert.codegpt.settings.service.ServiceType.OPENAI
import ee.carlrobert.codegpt.settings.service.ServiceType.*
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings
@ -14,12 +15,16 @@ abstract class CodeCompletionFeatureToggleActions(
private val enableFeatureAction: Boolean
) : DumbAwareAction() {
override fun actionPerformed(e: AnActionEvent) {
GeneralSettings.getCurrentState().selectedService
.takeIf { it in listOf(OPENAI, LLAMA_CPP) }
.takeIf { it in listOf(OPENAI, CUSTOM_OPENAI, LLAMA_CPP) }
?.also { selectedService ->
if (OPENAI == selectedService) {
OpenAISettings.getCurrentState().isCodeCompletionsEnabled = enableFeatureAction
} else if (CUSTOM_OPENAI == selectedService) {
service<CustomServiceSettings>().state.codeCompletionSettings.codeCompletionsEnabled =
enableFeatureAction
} else {
LlamaSettings.getCurrentState().isCodeCompletionsEnabled = enableFeatureAction
}
@ -31,7 +36,7 @@ abstract class CodeCompletionFeatureToggleActions(
val codeCompletionEnabled = isCodeCompletionsEnabled(selectedService)
e.presentation.isEnabled = codeCompletionEnabled != enableFeatureAction
e.presentation.isVisible =
e.presentation.isEnabled && listOf(OPENAI, LLAMA_CPP).contains(
e.presentation.isEnabled && listOf(OPENAI, CUSTOM_OPENAI, LLAMA_CPP).contains(
selectedService
)
}
@ -43,6 +48,7 @@ abstract class CodeCompletionFeatureToggleActions(
private fun isCodeCompletionsEnabled(serviceType: ServiceType): Boolean {
return when (serviceType) {
OPENAI -> OpenAISettings.getCurrentState().isCodeCompletionsEnabled
CUSTOM_OPENAI -> service<CustomServiceSettings>().state.codeCompletionSettings.codeCompletionsEnabled
LLAMA_CPP -> LlamaSettings.getCurrentState().isCodeCompletionsEnabled
else -> false
}

View file

@ -1,13 +1,26 @@
package ee.carlrobert.codegpt.codecompletions
import com.fasterxml.jackson.core.JsonProcessingException
import com.fasterxml.jackson.databind.ObjectMapper
import com.intellij.openapi.components.service
import ee.carlrobert.codegpt.completions.llama.LlamaModel
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey
import ee.carlrobert.codegpt.credentials.CredentialsStore.getCredential
import ee.carlrobert.codegpt.settings.configuration.Placeholder
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettingsState
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings
import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest
import ee.carlrobert.llm.client.openai.completion.request.OpenAITextCompletionRequest
import okhttp3.MediaType.Companion.toMediaType
import okhttp3.Request
import okhttp3.RequestBody.Companion.toRequestBody
import java.nio.charset.StandardCharsets
object CodeCompletionRequestFactory {
@JvmStatic
fun buildOpenAIRequest(details: InfillRequestDetails): OpenAITextCompletionRequest {
return OpenAITextCompletionRequest.Builder(details.prefix)
.setSuffix(details.suffix)
@ -17,6 +30,35 @@ object CodeCompletionRequestFactory {
.build()
}
@JvmStatic
fun buildCustomRequest(details: InfillRequestDetails): Request {
val settings = service<CustomServiceSettings>().state.codeCompletionSettings
val requestBuilder = Request.Builder().url(settings.url!!)
val credential = getCredential(CredentialKey.CUSTOM_SERVICE_API_KEY)
for (entry in settings.headers.entries) {
var value = entry.value
if (credential != null && value.contains("\$CUSTOM_SERVICE_API_KEY")) {
value = value.replace("\$CUSTOM_SERVICE_API_KEY", credential)
}
requestBuilder.addHeader(entry.key, value)
}
val transformedBody = settings.body.entries.associate { (key, value) ->
key to transformValue(value, settings.infillTemplate, details)
}
try {
val requestBody = ObjectMapper()
.writerWithDefaultPrettyPrinter()
.writeValueAsString(transformedBody)
.toByteArray(StandardCharsets.UTF_8)
.toRequestBody("application/json".toMediaType())
return requestBuilder.post(requestBody).build()
} catch (e: JsonProcessingException) {
throw RuntimeException(e)
}
}
@JvmStatic
fun buildLlamaRequest(details: InfillRequestDetails): LlamaCompletionRequest {
val settings = LlamaSettings.getCurrentState()
val promptTemplate = getLlamaInfillPromptTemplate(settings)
@ -38,4 +80,18 @@ object CodeCompletionRequestFactory {
}
return LlamaModel.findByHuggingFaceModel(settings.huggingFaceModel).infillPromptTemplate
}
private fun transformValue(
value: Any,
template: InfillPromptTemplate,
details: InfillRequestDetails
): Any {
if (value !is String) return value
return when (value) {
"$" + Placeholder.FIM_PROMPT -> template.buildPrompt(details.prefix, details.suffix)
"$" + Placeholder.PREFIX -> details.prefix
"$" + Placeholder.SUFFIX -> details.suffix
else -> value
}
}
}

View file

@ -5,6 +5,7 @@ import com.intellij.codeInsight.inline.completion.InlineCompletionProvider
import com.intellij.codeInsight.inline.completion.InlineCompletionProviderID
import com.intellij.codeInsight.inline.completion.InlineCompletionRequest
import com.intellij.codeInsight.inline.completion.elements.InlineCompletionGrayTextElement
import com.intellij.notification.NotificationType
import com.intellij.codeInsight.inline.completion.suggestion.InlineCompletionSingleSuggestion
import com.intellij.codeInsight.inline.completion.suggestion.InlineCompletionSuggestionUpdateManager
import com.intellij.codeInsight.inline.completion.suggestion.InlineCompletionSuggestionUpdateManager.UpdateResult
@ -12,12 +13,17 @@ import com.intellij.codeInsight.inline.completion.suggestion.InlineCompletionSug
import com.intellij.codeInsight.inline.completion.suggestion.InlineCompletionSuggestionUpdateManager.UpdateResult.Invalidated
import com.intellij.codeInsight.inline.completion.suggestion.InlineCompletionVariant
import com.intellij.openapi.application.EDT
import com.intellij.openapi.diagnostic.Logger
import com.intellij.openapi.components.service
import com.intellij.openapi.diagnostic.thisLogger
import ee.carlrobert.codegpt.CodeGPTKeys
import ee.carlrobert.codegpt.completions.CompletionRequestService
import ee.carlrobert.codegpt.settings.GeneralSettings
import ee.carlrobert.codegpt.settings.service.ServiceType
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings
import ee.carlrobert.codegpt.ui.OverlayUtil.showNotification
import ee.carlrobert.llm.client.openai.completion.ErrorDetails
import ee.carlrobert.llm.completion.CompletionEventListener
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.channels.awaitClose
@ -29,9 +35,8 @@ import okhttp3.sse.EventSource
import java.util.concurrent.atomic.AtomicReference
class CodeGPTInlineCompletionProvider : InlineCompletionProvider {
companion object {
private val LOG = Logger.getInstance(CodeGPTInlineCompletionProvider::class.java)
private val logger = thisLogger()
}
private val currentCall = AtomicReference<EventSource>(null)
@ -44,7 +49,7 @@ class CodeGPTInlineCompletionProvider : InlineCompletionProvider {
override suspend fun getSuggestion(request: InlineCompletionRequest): InlineCompletionSingleSuggestion {
if (request.editor.project == null) {
LOG.error("Could not find project")
logger.error("Could not find project")
return InlineCompletionSingleSuggestion.build(elements = emptyFlow())
}
@ -55,12 +60,14 @@ class CodeGPTInlineCompletionProvider : InlineCompletionProvider {
currentCall.set(
CompletionRequestService.getInstance().getCodeCompletionAsync(
infillRequest,
CodeCompletionEventListener(infillRequest) {
CodeCompletionEventListener {
val inlineText = it.takeWhile { message -> message != '\n' }.toString()
request.editor.putUserData(CodeGPTKeys.PREVIOUS_INLAY_TEXT, inlineText)
launch {
try {
trySend(InlineCompletionGrayTextElement(it))
trySend(InlineCompletionGrayTextElement(inlineText))
} catch (e: Exception) {
LOG.error("Failed to send inline completion suggestion", e)
logger.error("Failed to send inline completion suggestion", e)
}
}
}
@ -74,6 +81,7 @@ class CodeGPTInlineCompletionProvider : InlineCompletionProvider {
val selectedService = GeneralSettings.getCurrentState().selectedService
val codeCompletionsEnabled = when (selectedService) {
ServiceType.OPENAI -> OpenAISettings.getCurrentState().isCodeCompletionsEnabled
ServiceType.CUSTOM_OPENAI -> service<CustomServiceSettings>().state.codeCompletionSettings.codeCompletionsEnabled
ServiceType.LLAMA_CPP -> LlamaSettings.getCurrentState().isCodeCompletionsEnabled
else -> false
}
@ -85,24 +93,28 @@ class CodeGPTInlineCompletionProvider : InlineCompletionProvider {
}
internal class CodeCompletionEventListener(
private val requestDetails: InfillRequestDetails,
private val completed: (String) -> Unit
private val completed: (StringBuilder) -> Unit
) : CompletionEventListener<String> {
override fun onMessage(message: String?, eventSource: EventSource?) {
if (message != null && message.contains('\n')) {
eventSource?.cancel()
}
}
override fun onComplete(messageBuilder: StringBuilder) {
// TODO: https://youtrack.jetbrains.com/issue/CPP-38312/CLion-crashes-around-every-10-minutes-of-work
/*val processedOutput = CodeCompletionParserFactory
.getParserForFileExtension(requestDetails.fileExtension)
.parse(
requestDetails.prefix,
requestDetails.suffix,
messageBuilder.toString()
)*/
val output =
if (messageBuilder.contains("\n"))
messageBuilder.substring(0, messageBuilder.indexOf("\n"))
else messageBuilder.toString()
completed(output)
completed(messageBuilder)
}
override fun onCancelled(messageBuilder: StringBuilder) {
completed(messageBuilder)
}
override fun onError(error: ErrorDetails, ex: Throwable) {
if (ex.message == null || (ex.message != null && ex.message != "Canceled")) {
showNotification(error.message, NotificationType.ERROR)
logger.error(error.message, ex)
}
}
}

View file

@ -28,4 +28,4 @@ enum class InfillPromptTemplate(val label: String, val stopTokens: List<String>?
override fun toString(): String {
return label
}
}
}

View file

@ -4,11 +4,10 @@ import com.intellij.codeInsight.inline.completion.InlineCompletionRequest
import com.intellij.openapi.editor.Document
import com.intellij.openapi.util.TextRange
import ee.carlrobert.codegpt.EncodingManager
import ee.carlrobert.codegpt.util.file.FileUtil
import kotlin.math.max
import kotlin.math.min
class InfillRequestDetails(val prefix: String, val suffix: String, val fileExtension: String) {
class InfillRequestDetails(val prefix: String, val suffix: String) {
companion object {
private const val MAX_OFFSET = 10_000
private const val MAX_PROMPT_TOKENS = 128
@ -17,18 +16,16 @@ class InfillRequestDetails(val prefix: String, val suffix: String, val fileExten
return fromDocumentWithMaxOffset(
request.editor.document,
request.editor.caretModel.offset,
FileUtil.getFileExtension(request.file.name)
)
}
private fun fromDocumentWithMaxOffset(
document: Document,
caretOffset: Int,
fileExtension: String
): InfillRequestDetails {
val start = max(0, (caretOffset - MAX_OFFSET))
val end = min(document.textLength, (caretOffset + MAX_OFFSET))
return fromDocumentWithCustomRange(document, caretOffset, start, end, fileExtension)
return fromDocumentWithCustomRange(document, caretOffset, start, end)
}
private fun fromDocumentWithCustomRange(
@ -36,11 +33,10 @@ class InfillRequestDetails(val prefix: String, val suffix: String, val fileExten
caretOffset: Int,
start: Int,
end: Int,
fileExtension: String
): InfillRequestDetails {
val prefix: String = truncateText(document, start, caretOffset, false)
val suffix: String = truncateText(document, caretOffset, end, true)
return InfillRequestDetails(prefix, suffix, fileExtension)
return InfillRequestDetails(prefix, suffix)
}
private fun truncateText(

View file

@ -12,14 +12,23 @@ object CredentialsStore {
CredentialKey.values().forEach {
val credentialAttributes = CredentialAttributes(generateServiceName("CodeGPT", it.name))
val password = PasswordSafe.instance.getPassword(credentialAttributes)
setCredential(it, password)
// Avoid calling setCredential here since it will persist
// the password back into the PasswordSafe unnecessarily.
credentialsMap[it] = password
}
}
fun getCredential(key: CredentialKey): String? = credentialsMap[key]
fun setCredential(key: CredentialKey, password: String?) {
val prevPassword = credentialsMap[key]
credentialsMap[key] = password
if (prevPassword != password) {
val credentialAttributes = CredentialAttributes(generateServiceName("CodeGPT", key.name))
PasswordSafe.instance.setPassword(credentialAttributes, password)
}
}
fun isCredentialSet(key: CredentialKey): Boolean = !getCredential(key).isNullOrEmpty()

View file

@ -0,0 +1,41 @@
package ee.carlrobert.codegpt.settings.configuration
import com.intellij.openapi.components.Service
import com.intellij.openapi.components.Service.Level.PROJECT
import com.intellij.openapi.components.service
import com.intellij.openapi.project.Project
import ee.carlrobert.codegpt.settings.configuration.Placeholder.BRANCH_NAME
import ee.carlrobert.codegpt.settings.configuration.Placeholder.DATE_ISO_8601
@Service(PROJECT)
class CommitMessageTemplate private constructor(project: Project) {

    // One replacement strategy per placeholder supported in commit-message prompts.
    private val placeholderStrategyMapping: Map<Placeholder, PlaceholderStrategy> = mapOf(
        BRANCH_NAME to BranchNamePlaceholderStrategy(project),
        DATE_ISO_8601 to DatePlaceholderStrategy()
    )

    /**
     * Returns the user-configured commit-message prompt with every
     * `{PLACEHOLDER_NAME}` token replaced by its current runtime value.
     */
    fun getSystemPrompt(): String {
        var prompt = service<ConfigurationSettings>().state.commitMessagePrompt
        for ((placeholder, strategy) in placeholderStrategyMapping) {
            prompt = prompt.replace("{${placeholder.name}}", strategy.getReplacementValue())
        }
        return prompt
    }

    companion object {
        /**
         * Builds the HTML help text listing the placeholders supported by the
         * commit-message prompt template.
         */
        fun getHtmlDescription(): String {
            val listItems = listOf(BRANCH_NAME, DATE_ISO_8601).joinToString("\n") { placeholder ->
                "<li><strong>${placeholder.name}</strong>: ${placeholder.description}</li>"
            }
            return "<html>\n" +
                    "<body>\n" +
                    "<p>Template for generating commit messages. Use the following placeholders to insert dynamic values:</p>\n" +
                    "<ul>$listItems</ul>\n" +
                    "</body>\n" +
                    "</html>"
        }
    }
}

View file

@ -0,0 +1,39 @@
package ee.carlrobert.codegpt.settings.configuration
import com.intellij.openapi.project.Project
import git4idea.GitUtil
import git4idea.branch.GitBranchUtil
import java.time.LocalDate
/**
 * Tokens that may appear in user-configurable prompt templates and are
 * substituted with dynamic values when a request is built.
 */
enum class Placeholder(val description: String) {
    DATE_ISO_8601("Current date in ISO 8601 format, e.g. 2021-01-01."),
    BRANCH_NAME("The name of the current branch."),
    PREFIX("Code before the cursor."),
    SUFFIX("Code after the cursor."),
    FIM_PROMPT("Prebuilt Fill-In-The-Middle (FIM) prompt using the specified template."),
}
/** Supplies the runtime value substituted for a single [Placeholder]. */
interface PlaceholderStrategy {
    fun getReplacementValue(): String
}
/** Resolves [Placeholder.DATE_ISO_8601] to today's date in ISO-8601 form (yyyy-MM-dd). */
class DatePlaceholderStrategy : PlaceholderStrategy {
    override fun getReplacementValue(): String = LocalDate.now().toString()
}
/**
 * Resolves [Placeholder.BRANCH_NAME] to the current Git branch of [project].
 *
 * Falls back to "BRANCH-UNKNOWN" when the project has zero or multiple Git
 * repositories (the branch would be ambiguous) or when Git state cannot be read.
 */
class BranchNamePlaceholderStrategy(val project: Project) : PlaceholderStrategy {

    override fun getReplacementValue(): String {
        return try {
            val repositories = GitUtil.getRepositoryManager(project).repositories
            // `size != 1` already covers the empty case; the previous
            // `isEmpty() || size != 1` check was redundant.
            if (repositories.size != 1) {
                return "BRANCH-UNKNOWN"
            }
            GitBranchUtil.getBranchNameOrRev(repositories[0])
        } catch (ignore: Exception) {
            // Deliberate best-effort: any Git failure degrades to the fallback value.
            "BRANCH-UNKNOWN"
        }
    }
}

View file

@ -0,0 +1,101 @@
package ee.carlrobert.codegpt.settings.service.custom
import com.intellij.openapi.ui.MessageType
import com.intellij.ui.components.JBTextField
import com.intellij.util.ui.FormBuilder
import ee.carlrobert.codegpt.CodeGPTBundle
import ee.carlrobert.codegpt.completions.CompletionRequestProvider
import ee.carlrobert.codegpt.completions.CompletionRequestService
import ee.carlrobert.codegpt.ui.OverlayUtil
import ee.carlrobert.llm.client.openai.completion.ErrorDetails
import ee.carlrobert.llm.completion.CompletionEventListener
import okhttp3.sse.EventSource
import java.awt.BorderLayout
import javax.swing.JButton
import javax.swing.JPanel
import javax.swing.SwingUtilities
/**
 * Settings form for the chat-completions endpoint of a custom OpenAI-compatible
 * service: URL field, header/body tabs, and a "test connection" button that
 * fires a real streaming request against the configured endpoint.
 */
class CustomServiceChatCompletionForm(state: CustomServiceChatCompletionSettingsState) {

    private val urlField = JBTextField(state.url, 30)
    private val tabbedPane = CustomServiceFormTabbedPane(state.headers, state.body)
    private val testConnectionButton = JButton(
        CodeGPTBundle.get("settingsConfigurable.service.custom.openai.testConnection.label")
    )

    init {
        testConnectionButton.addActionListener { testConnection() }
    }

    // The url/headers/body properties delegate directly to the underlying
    // Swing components, so reads always reflect the current UI state.
    var url: String
        get() = urlField.text
        set(url) {
            urlField.text = url
        }

    var headers: MutableMap<String, String>
        get() = tabbedPane.headers
        set(value) {
            tabbedPane.headers = value
        }

    var body: MutableMap<String, Any>
        get() = tabbedPane.body
        set(value) {
            tabbedPane.body = value
        }

    // Lazily builds the panel: URL row (with the test button on the right),
    // followed by the header/body tabbed pane.
    val form: JPanel
        get() = FormBuilder.createFormBuilder()
            .addVerticalGap(8)
            .addLabeledComponent(
                CodeGPTBundle.get("settingsConfigurable.service.custom.openai.url.label"),
                JPanel(BorderLayout(8, 0)).apply {
                    add(urlField, BorderLayout.CENTER)
                    add(testConnectionButton, BorderLayout.EAST)
                }
            )
            .addComponent(tabbedPane)
            .addComponentFillVertically(JPanel(), 0)
            .panel

    /** Restores all fields from the persisted [settings] state. */
    fun resetForm(settings: CustomServiceChatCompletionSettingsState) {
        urlField.text = settings.url
        tabbedPane.headers = settings.headers
        tabbedPane.body = settings.body
    }

    // Sends a minimal chat request; the listener reports success/failure
    // as a balloon anchored to the test button.
    private fun testConnection() {
        CompletionRequestService.getInstance().getCustomOpenAIChatCompletionAsync(
            CompletionRequestProvider.buildCustomOpenAICompletionRequest("Hello!"),
            TestConnectionEventListener()
        )
    }

    internal inner class TestConnectionEventListener : CompletionEventListener<String?> {
        override fun onMessage(value: String?, eventSource: EventSource) {
            // First non-empty token proves the endpoint works; cancel the
            // stream immediately rather than consuming the full response.
            if (!value.isNullOrEmpty()) {
                SwingUtilities.invokeLater {
                    OverlayUtil.showBalloon(
                        CodeGPTBundle.get("settingsConfigurable.service.custom.openai.connectionSuccess"),
                        MessageType.INFO,
                        testConnectionButton
                    )
                    eventSource.cancel()
                }
            }
        }

        override fun onError(error: ErrorDetails, ex: Throwable) {
            SwingUtilities.invokeLater {
                OverlayUtil.showBalloon(
                    CodeGPTBundle.get("settingsConfigurable.service.custom.openai.connectionFailed")
                        + "\n\n"
                        + error.message,
                    MessageType.ERROR,
                    testConnectionButton
                )
            }
        }
    }
}

View file

@ -0,0 +1,124 @@
package ee.carlrobert.codegpt.settings.service.custom
/**
 * Preset chat-completion endpoint configurations for known OpenAI-compatible
 * providers. Each constant carries the endpoint [url] plus default request
 * [headers] and [body] parameters; `$CUSTOM_SERVICE_API_KEY` and
 * `$OPENAI_MESSAGES` are placeholder strings resolved at request time.
 *
 * NOTE(review): [headers] and [body] expose the per-constant maps as
 * MutableMap — callers that mutate them mutate the shared template instance.
 * Confirm callers always copy before editing.
 */
enum class CustomServiceChatCompletionTemplate(
    val url: String,
    val headers: MutableMap<String, String>,
    val body: MutableMap<String, Any>
) {
    ANYSCALE(
        "https://api.endpoints.anyscale.com/v1/chat/completions",
        getDefaultHeadersWithAuthentication(),
        getDefaultBodyParams(
            mapOf(
                "model" to "mistralai/Mixtral-8x7B-Instruct-v0.1",
                "max_tokens" to 1024
            )
        )
    ),
    AZURE(
        // Azure authenticates with an "api-key" header instead of a Bearer token.
        "https://{your-resource-name}.openai.azure.com/openai/deployments/{deployment-id}/chat/completions?api-version=2023-05-15",
        getDefaultHeaders("api-key", "\$CUSTOM_SERVICE_API_KEY"),
        getDefaultBodyParams(emptyMap())
    ),
    DEEP_INFRA(
        "https://api.deepinfra.com/v1/openai/chat/completions",
        getDefaultHeadersWithAuthentication(),
        getDefaultBodyParams(
            mapOf(
                "model" to "meta-llama/Llama-2-70b-chat-hf",
                "max_tokens" to 1024
            )
        )
    ),
    FIREWORKS(
        "https://api.fireworks.ai/inference/v1/chat/completions",
        getDefaultHeadersWithAuthentication(),
        getDefaultBodyParams(
            mapOf(
                "model" to "accounts/fireworks/models/llama-v2-7b-chat",
                "max_tokens" to 1024
            )
        )
    ),
    GROQ(
        "https://api.groq.com/openai/v1/chat/completions",
        getDefaultHeadersWithAuthentication(),
        getDefaultBodyParams(
            mapOf(
                "model" to "codellama-34b",
                "max_tokens" to 1024
            )
        )
    ),
    OPENAI(
        "https://api.openai.com/v1/chat/completions",
        getDefaultHeaders("Authorization", "Bearer \$CUSTOM_SERVICE_API_KEY"),
        getDefaultBodyParams(
            mapOf(
                "model" to "gpt-4",
                "max_tokens" to 1024
            )
        )
    ),
    PERPLEXITY(
        "https://api.perplexity.ai/chat/completions",
        getDefaultHeadersWithAuthentication(),
        getDefaultBodyParams(
            mapOf(
                "model" to "codellama",
                "max_tokens" to 1024
            )
        )
    ),
    TOGETHER(
        "https://api.together.xyz/v1/chat/completions",
        getDefaultHeaders("Authorization", "Bearer \$CUSTOM_SERVICE_API_KEY"),
        getDefaultBodyParams(
            mapOf(
                "model" to "deepseek-ai/deepseek-coder-33b-instruct",
                "max_tokens" to 1024
            )
        )
    ),
    OLLAMA(
        // Local endpoints need no authentication headers.
        "http://localhost:11434/v1/chat/completions",
        getDefaultHeaders(),
        getDefaultBodyParams(mapOf("model" to "codellama"))
    ),
    LLAMA_CPP(
        "http://localhost:8080/v1/chat/completions",
        getDefaultHeaders(),
        getDefaultBodyParams(emptyMap())
    );
}
// Shared defaults for the chat-completion preset templates above.

private fun getDefaultHeadersWithAuthentication(): MutableMap<String, String> =
    getDefaultHeaders("Authorization", "Bearer \$CUSTOM_SERVICE_API_KEY")

private fun getDefaultHeaders(): MutableMap<String, String> = getDefaultHeaders(emptyMap())

private fun getDefaultHeaders(key: String, value: String): MutableMap<String, String> =
    getDefaultHeaders(mapOf(key to value))

// Content-type and application-tag headers common to every template,
// merged with any provider-specific additions.
private fun getDefaultHeaders(additionalHeaders: Map<String, String>): MutableMap<String, String> =
    mutableMapOf(
        "Content-Type" to "application/json",
        "X-LLM-Application-Tag" to "codegpt"
    ).apply { putAll(additionalHeaders) }

// Streaming chat request body defaults; $OPENAI_MESSAGES is resolved at request time.
private fun getDefaultBodyParams(additionalParams: Map<String, Any>): MutableMap<String, Any> =
    mutableMapOf<String, Any>(
        "stream" to true,
        "messages" to "\$OPENAI_MESSAGES",
        "temperature" to 0.1
    ).apply { putAll(additionalParams) }

View file

@ -0,0 +1,182 @@
package ee.carlrobert.codegpt.settings.service.custom
import com.intellij.icons.AllIcons.General
import com.intellij.ide.HelpTooltip
import com.intellij.openapi.ui.ComboBox
import com.intellij.openapi.ui.MessageType
import com.intellij.openapi.ui.panel.ComponentPanelBuilder
import com.intellij.ui.EnumComboBoxModel
import com.intellij.ui.components.JBCheckBox
import com.intellij.ui.components.JBLabel
import com.intellij.ui.components.JBTextField
import com.intellij.util.ui.FormBuilder
import ee.carlrobert.codegpt.CodeGPTBundle
import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestFactory
import ee.carlrobert.codegpt.codecompletions.InfillPromptTemplate
import ee.carlrobert.codegpt.codecompletions.InfillRequestDetails
import ee.carlrobert.codegpt.completions.CompletionRequestService
import ee.carlrobert.codegpt.settings.configuration.Placeholder
import ee.carlrobert.codegpt.ui.OverlayUtil
import ee.carlrobert.llm.client.openai.completion.ErrorDetails
import ee.carlrobert.llm.completion.CompletionEventListener
import okhttp3.sse.EventSource
import org.apache.commons.text.StringEscapeUtils
import java.awt.BorderLayout
import java.awt.FlowLayout
import javax.swing.Box
import javax.swing.JButton
import javax.swing.JPanel
import javax.swing.SwingUtilities
/**
 * Settings form for the code-completions (FIM) endpoint of a custom
 * OpenAI-compatible service: enable checkbox, FIM template picker with a help
 * tooltip, URL field, header/body tabs, and a "test connection" button.
 */
class CustomServiceCodeCompletionForm(state: CustomServiceCodeCompletionSettingsState) {

    private val featureEnabledCheckBox = JBCheckBox(
        CodeGPTBundle.get("codeCompletionsForm.enableFeatureText"),
        state.codeCompletionsEnabled
    )

    private val promptTemplateComboBox =
        ComboBox(EnumComboBoxModel(InfillPromptTemplate::class.java)).apply {
            // Pre-select the persisted template. The previous implementation
            // additionally called setSelectedItem(InfillPromptTemplate.LLAMA)
            // right after, silently discarding the saved setting.
            selectedItem = state.infillTemplate
            addItemListener {
                // ItemListener fires for both DESELECTED (old item) and
                // SELECTED (new item); only the new selection matters here.
                if (it.stateChange == java.awt.event.ItemEvent.SELECTED) {
                    updatePromptTemplateHelpTooltip(it.item as InfillPromptTemplate)
                }
            }
        }

    private val promptTemplateHelpText = JBLabel(General.ContextHelp)
    private val urlField = JBTextField(state.url, 30)
    private val tabbedPane = CustomServiceFormTabbedPane(state.headers, state.body)
    private val testConnectionButton = JButton(
        CodeGPTBundle.get("settingsConfigurable.service.custom.openai.testConnection.label")
    )

    init {
        testConnectionButton.addActionListener { testConnection() }
        updatePromptTemplateHelpTooltip(state.infillTemplate)
    }

    // The properties below delegate directly to the Swing components,
    // so reads always reflect the current UI state.
    var codeCompletionsEnabled: Boolean
        get() = featureEnabledCheckBox.isSelected
        set(enabled) {
            featureEnabledCheckBox.isSelected = enabled
        }

    var infillTemplate: InfillPromptTemplate
        get() = promptTemplateComboBox.item
        set(template) {
            promptTemplateComboBox.selectedItem = template
        }

    var url: String
        get() = urlField.text
        set(url) {
            urlField.text = url
        }

    var headers: MutableMap<String, String>
        get() = tabbedPane.headers
        set(value) {
            tabbedPane.headers = value
        }

    var body: MutableMap<String, Any>
        get() = tabbedPane.body
        set(value) {
            tabbedPane.body = value
        }

    // Lazily builds the panel: enable checkbox, FIM template row with help
    // icon, URL row (with the test button), tabs, and a placeholder legend.
    val form: JPanel
        get() = FormBuilder.createFormBuilder()
            .addVerticalGap(8)
            .addComponent(featureEnabledCheckBox)
            .addVerticalGap(4)
            .addLabeledComponent(
                "FIM template:",
                JPanel(FlowLayout(FlowLayout.LEADING, 0, 0)).apply {
                    add(promptTemplateComboBox)
                    add(Box.createHorizontalStrut(4))
                    add(promptTemplateHelpText)
                })
            .addLabeledComponent(
                CodeGPTBundle.get("settingsConfigurable.service.custom.openai.url.label"),
                JPanel(BorderLayout(8, 0)).apply {
                    add(urlField, BorderLayout.CENTER)
                    add(testConnectionButton, BorderLayout.EAST)
                }
            )
            .addComponent(tabbedPane)
            .addComponent(ComponentPanelBuilder.createCommentComponent(getHtmlDescription(), true, 100))
            .addComponentFillVertically(JPanel(), 0)
            .panel

    // HTML legend for the placeholders usable in the request body.
    private fun getHtmlDescription(): String {
        val placeholderDescriptions = listOf(
            Placeholder.FIM_PROMPT,
            Placeholder.PREFIX,
            Placeholder.SUFFIX
        ).joinToString("\n") {
            "<li><strong>\$${it.name}</strong>: ${it.description}</li>"
        }
        return buildString {
            append("<html>\n")
            append("<body>\n")
            append("<p>Use the following placeholders to insert dynamic values:</p>\n")
            append("<ul>$placeholderDescriptions</ul>\n")
            append("</body>\n")
            append("</html>")
        }
    }

    /** Restores all fields from the persisted [settings] state. */
    fun resetForm(settings: CustomServiceCodeCompletionSettingsState) {
        featureEnabledCheckBox.isSelected = settings.codeCompletionsEnabled
        promptTemplateComboBox.selectedItem = settings.infillTemplate
        urlField.text = settings.url
        tabbedPane.headers = settings.headers
        tabbedPane.body = settings.body
        updatePromptTemplateHelpTooltip(settings.infillTemplate)
    }

    // Sends a minimal completion request; the listener reports success/failure
    // as a balloon anchored to the test button.
    private fun testConnection() {
        CompletionRequestService.getInstance().getCustomOpenAICompletionAsync(
            CodeCompletionRequestFactory.buildCustomRequest(InfillRequestDetails("Hello", "!")),
            TestConnectionEventListener()
        )
    }

    internal inner class TestConnectionEventListener : CompletionEventListener<String?> {
        override fun onMessage(value: String?, eventSource: EventSource) {
            // First non-empty token proves the endpoint works; cancel the
            // stream immediately rather than consuming the full response.
            if (!value.isNullOrEmpty()) {
                SwingUtilities.invokeLater {
                    OverlayUtil.showBalloon(
                        CodeGPTBundle.get("settingsConfigurable.service.custom.openai.connectionSuccess"),
                        MessageType.INFO,
                        testConnectionButton
                    )
                    eventSource.cancel()
                }
            }
        }

        override fun onError(error: ErrorDetails, ex: Throwable) {
            SwingUtilities.invokeLater {
                OverlayUtil.showBalloon(
                    CodeGPTBundle.get("settingsConfigurable.service.custom.openai.connectionFailed")
                        + "\n\n"
                        + error.message,
                    MessageType.ERROR,
                    testConnectionButton
                )
            }
        }
    }

    // Shows the rendered FIM prompt (with literal PREFIX/SUFFIX markers)
    // as a tooltip on the help icon.
    private fun updatePromptTemplateHelpTooltip(template: InfillPromptTemplate) {
        promptTemplateHelpText.setToolTipText(null)
        val description = StringEscapeUtils.escapeHtml4(template.buildPrompt("PREFIX", "SUFFIX"))
        HelpTooltip()
            .setTitle(template.toString())
            .setDescription("<html><p>$description</p></html>")
            .installOn(promptTemplateHelpText)
    }
}

View file

@ -0,0 +1,73 @@
package ee.carlrobert.codegpt.settings.service.custom
/**
 * Preset code-completion (FIM) endpoint configurations for known providers.
 * Each constant carries the endpoint [url] plus default request [headers] and
 * [body]; `$CUSTOM_SERVICE_API_KEY`, `$FIM_PROMPT`, `$PREFIX` and `$SUFFIX`
 * are placeholder strings resolved at request time.
 *
 * NOTE(review): [headers] and [body] expose the per-constant maps as
 * MutableMap — callers that mutate them mutate the shared template instance.
 */
enum class CustomServiceCodeCompletionTemplate(
    val url: String,
    val headers: MutableMap<String, String>,
    val body: MutableMap<String, Any>
) {
    ANYSCALE(
        "https://api.endpoints.anyscale.com/v1/completions",
        getDefaultHeadersWithAuthentication(),
        getDefaultBodyParams(mapOf("model" to "codellama/CodeLlama-70b-Instruct-hf"))
    ),
    AZURE(
        // Azure authenticates with an "api-key" header instead of a Bearer token.
        "https://{your-resource-name}.openai.azure.com/openai/deployments/{deployment-id}/completions?api-version=2023-05-15",
        getDefaultHeaders("api-key", "\$CUSTOM_SERVICE_API_KEY"),
        getDefaultBodyParams(emptyMap())
    ),
    DEEP_INFRA(
        // DeepInfra's inference API takes the prompt under "input" rather than
        // the OpenAI-style body, hence the fully custom body map.
        "https://api.deepinfra.com/v1/inference/codellama/CodeLlama-70b-Instruct-hf",
        getDefaultHeadersWithAuthentication(),
        mutableMapOf("input" to "\$FIM_PROMPT")
    ),
    FIREWORKS(
        "https://api.fireworks.ai/inference/v1/completions",
        getDefaultHeadersWithAuthentication(),
        getDefaultBodyParams(mapOf("model" to "accounts/fireworks/models/starcoder-16b"))
    ),
    OPENAI(
        // OpenAI supports native prefix/suffix infilling, so prompt/suffix are
        // sent separately instead of a prebuilt FIM prompt.
        "https://api.openai.com/v1/completions",
        getDefaultHeaders("Authorization", "Bearer \$CUSTOM_SERVICE_API_KEY"),
        mutableMapOf(
            "stream" to true,
            "prompt" to "\$PREFIX",
            "suffix" to "\$SUFFIX",
            "model" to "gpt-3.5-turbo-instruct",
            "temperature" to 0.2,
            "max_tokens" to 24
        )
    ),
    TOGETHER(
        "https://api.together.xyz/v1/completions",
        getDefaultHeaders("Authorization", "Bearer \$CUSTOM_SERVICE_API_KEY"),
        getDefaultBodyParams(mapOf("model" to "codellama/CodeLlama-70b-hf"))
    )
}
// Shared defaults for the code-completion preset templates above.

private fun getDefaultHeadersWithAuthentication(): MutableMap<String, String> =
    getDefaultHeaders("Authorization", "Bearer \$CUSTOM_SERVICE_API_KEY")

private fun getDefaultHeaders(key: String, value: String): MutableMap<String, String> =
    getDefaultHeaders(mapOf(key to value))

// Content-type and application-tag headers common to every template,
// merged with any provider-specific additions.
private fun getDefaultHeaders(additionalHeaders: Map<String, String>): MutableMap<String, String> =
    mutableMapOf(
        "Content-Type" to "application/json",
        "X-LLM-Application-Tag" to "codegpt"
    ).apply { putAll(additionalHeaders) }

// Streaming FIM request body defaults; $FIM_PROMPT is resolved at request time.
private fun getDefaultBodyParams(additionalParams: Map<String, Any>): MutableMap<String, Any> =
    mutableMapOf<String, Any>(
        "stream" to true,
        "prompt" to "\$FIM_PROMPT",
        "temperature" to 0.2,
        "max_tokens" to 24
    ).apply { putAll(additionalParams) }

View file

@ -0,0 +1,148 @@
package ee.carlrobert.codegpt.settings.service.custom
import com.intellij.icons.AllIcons.General
import com.intellij.ide.HelpTooltip
import com.intellij.openapi.components.service
import com.intellij.openapi.ui.ComboBox
import com.intellij.ui.EnumComboBoxModel
import com.intellij.ui.TitledSeparator
import com.intellij.ui.components.JBLabel
import com.intellij.ui.components.JBPasswordField
import com.intellij.util.ui.FormBuilder
import ee.carlrobert.codegpt.CodeGPTBundle
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey
import ee.carlrobert.codegpt.credentials.CredentialsStore.getCredential
import ee.carlrobert.codegpt.ui.UIUtil
import java.awt.FlowLayout
import java.net.MalformedURLException
import java.net.URL
import javax.swing.Box
import javax.swing.JPanel
import javax.swing.JTabbedPane
/**
 * Top-level settings form for the "custom OpenAI-compatible service" provider:
 * API key, preset-template picker, and tabbed chat/code-completion sub-forms.
 */
class CustomServiceForm {

    private val apiKeyField = JBPasswordField().apply {
        columns = 30
        text = getCredential(CredentialKey.CUSTOM_SERVICE_API_KEY)
    }
    private val templateHelpText = JBLabel(General.ContextHelp)
    private val templateComboBox = ComboBox(EnumComboBoxModel(CustomServiceTemplate::class.java))
    private val chatCompletionsForm: CustomServiceChatCompletionForm
    private val codeCompletionsForm: CustomServiceCodeCompletionForm
    private val tabbedPane: JTabbedPane

    init {
        val state = service<CustomServiceSettings>().state
        chatCompletionsForm = CustomServiceChatCompletionForm(state.chatCompletionSettings)
        codeCompletionsForm = CustomServiceCodeCompletionForm(state.codeCompletionSettings)
        tabbedPane = JTabbedPane().apply {
            add(CodeGPTBundle.get("shared.chatCompletions"), chatCompletionsForm.form)
            add(CodeGPTBundle.get("shared.codeCompletions"), codeCompletionsForm.form)
        }
        templateComboBox.selectedItem = state.template
        templateComboBox.addItemListener {
            // ItemListener fires twice per change: DESELECTED with the old item,
            // then SELECTED with the new one. Without this guard the forms were
            // transiently reset to the OLD template's defaults first.
            if (it.stateChange == java.awt.event.ItemEvent.SELECTED) {
                applyTemplate(it.item as CustomServiceTemplate)
            }
        }
        // Apply the code-completions tab state for the initially persisted
        // template too; previously this only happened on a selection change,
        // leaving the tab enabled for templates without code-completion support.
        if (state.template.codeCompletionTemplate == null) {
            tabbedPane.selectedIndex = 0
            tabbedPane.setEnabledAt(1, false)
        }
        updateTemplateHelpTextTooltip(state.template)
    }

    // Overwrites both sub-forms with the selected template's defaults and
    // enables/disables the code-completions tab accordingly.
    private fun applyTemplate(template: CustomServiceTemplate) {
        updateTemplateHelpTextTooltip(template)
        chatCompletionsForm.run {
            url = template.chatCompletionTemplate.url
            headers = template.chatCompletionTemplate.headers
            body = template.chatCompletionTemplate.body
        }
        val codeCompletionTemplate = template.codeCompletionTemplate
        if (codeCompletionTemplate != null) {
            codeCompletionsForm.run {
                url = codeCompletionTemplate.url
                headers = codeCompletionTemplate.headers
                body = codeCompletionTemplate.body
            }
            tabbedPane.setEnabledAt(1, true)
        } else {
            tabbedPane.selectedIndex = 0
            tabbedPane.setEnabledAt(1, false)
        }
    }

    fun getForm(): JPanel = FormBuilder.createFormBuilder()
        .addComponent(TitledSeparator(CodeGPTBundle.get("shared.configuration")))
        .addComponent(
            FormBuilder.createFormBuilder()
                .setFormLeftIndent(16)
                .addLabeledComponent(
                    CodeGPTBundle.get("settingsConfigurable.service.custom.openai.presetTemplate.label"),
                    JPanel(FlowLayout(FlowLayout.LEADING, 0, 0)).apply {
                        add(templateComboBox)
                        add(Box.createHorizontalStrut(8))
                        add(templateHelpText)
                    }
                )
                .addLabeledComponent(
                    CodeGPTBundle.get("settingsConfigurable.shared.apiKey.label"),
                    apiKeyField
                )
                .addComponentToRightColumn(
                    UIUtil.createComment("settingsConfigurable.service.custom.openai.apiKey.comment")
                )
                .addVerticalGap(4)
                .addComponent(tabbedPane)
                .panel
        )
        .panel

    /** Returns the entered API key, or null when the field is empty. */
    fun getApiKey() = String(apiKeyField.password).ifEmpty { null }

    /** True when any field differs from the persisted settings/credentials. */
    fun isModified() = service<CustomServiceSettings>().state.run {
        templateComboBox.selectedItem != template
            || chatCompletionsForm.url != chatCompletionSettings.url
            || chatCompletionsForm.headers != chatCompletionSettings.headers
            || chatCompletionsForm.body != chatCompletionSettings.body
            || codeCompletionsForm.codeCompletionsEnabled != codeCompletionSettings.codeCompletionsEnabled
            || codeCompletionsForm.infillTemplate != codeCompletionSettings.infillTemplate
            || codeCompletionsForm.url != codeCompletionSettings.url
            || codeCompletionsForm.headers != codeCompletionSettings.headers
            || codeCompletionsForm.body != codeCompletionSettings.body
            || getApiKey() != getCredential(CredentialKey.CUSTOM_SERVICE_API_KEY)
    }

    /** Writes the current form values into the persisted settings state. */
    fun applyChanges() {
        service<CustomServiceSettings>().state.run {
            template = templateComboBox.item
            chatCompletionSettings = CustomServiceChatCompletionSettingsState().apply {
                url = chatCompletionsForm.url
                headers = chatCompletionsForm.headers
                body = chatCompletionsForm.body
            }
            codeCompletionSettings = CustomServiceCodeCompletionSettingsState().apply {
                codeCompletionsEnabled = codeCompletionsForm.codeCompletionsEnabled
                infillTemplate = codeCompletionsForm.infillTemplate
                url = codeCompletionsForm.url
                headers = codeCompletionsForm.headers
                body = codeCompletionsForm.body
            }
        }
    }

    /** Restores all fields from the persisted settings state. */
    fun resetForm() {
        service<CustomServiceSettings>().state.run {
            templateComboBox.item = template
            chatCompletionsForm.resetForm(chatCompletionSettings)
            codeCompletionsForm.resetForm(codeCompletionSettings)
        }
    }

    // Attaches a tooltip with a docs link for the given provider template.
    private fun updateTemplateHelpTextTooltip(template: CustomServiceTemplate) {
        templateHelpText.toolTipText = null
        try {
            HelpTooltip()
                .setTitle(template.providerName)
                .setBrowserLink(
                    CodeGPTBundle.get("settingsConfigurable.service.custom.openai.linkToDocs"),
                    URL(template.docsUrl)
                )
                .installOn(templateHelpText)
        } catch (e: MalformedURLException) {
            // docsUrl values are compile-time constants; a malformed one is a
            // programming error, so fail loudly.
            throw RuntimeException(e)
        }
    }
}

View file

@ -0,0 +1,76 @@
package ee.carlrobert.codegpt.settings.service.custom
import com.intellij.openapi.components.*
import com.intellij.util.xmlb.annotations.OptionTag
import ee.carlrobert.codegpt.codecompletions.InfillPromptTemplate
import ee.carlrobert.codegpt.util.MapConverter
/**
 * Persistent application-level settings for the custom OpenAI-compatible
 * service provider, stored in CodeGPT_CustomServiceSettings.xml.
 */
@Service
@State(
    name = "CodeGPT_CustomServiceSettings",
    storages = [Storage("CodeGPT_CustomServiceSettings.xml")]
)
class CustomServiceSettings :
    SimplePersistentStateComponent<CustomServiceState>(CustomServiceState()) {

    override fun loadState(state: CustomServiceState) {
        // Legacy persisted data kept url/headers/body at the top level; if any
        // of them is present, copy it into chatCompletionSettings and clear the
        // deprecated fields so the migrated shape is what gets persisted next.
        if (state.url != null || state.body.isNotEmpty() || state.headers.isNotEmpty()) {
            super.loadState(this.state.apply {
                // Migrate old settings
                template = state.template
                chatCompletionSettings.url = state.url
                chatCompletionSettings.body = state.body
                chatCompletionSettings.headers = state.headers
                url = null
                body = mutableMapOf()
                headers = mutableMapOf()
            })
        } else {
            super.loadState(state)
        }
    }
}
/**
 * Root persisted state: selected preset template plus the chat- and
 * code-completion endpoint settings. The deprecated top-level fields remain
 * only so old XML can be read and migrated (see CustomServiceSettings.loadState).
 */
class CustomServiceState : BaseState() {
    var template by enum(CustomServiceTemplate.OPENAI)
    var chatCompletionSettings by property(CustomServiceChatCompletionSettingsState())
    var codeCompletionSettings by property(CustomServiceCodeCompletionSettingsState())

    @Deprecated("", ReplaceWith("this.chatCompletionSettings.url"))
    var url by string()

    @Deprecated("", ReplaceWith("this.chatCompletionSettings.headers"))
    var headers by map<String, String>()

    @get:OptionTag(converter = MapConverter::class)
    @Deprecated("", ReplaceWith("this.chatCompletionSettings.body"))
    var body by map<String, Any>()
}
/**
 * Persisted chat-completion endpoint settings, pre-populated with the
 * OPENAI preset's defaults.
 */
class CustomServiceChatCompletionSettingsState : BaseState() {
    var url by string(CustomServiceChatCompletionTemplate.OPENAI.url)
    var headers by map<String, String>()

    // Heterogeneous values (booleans, numbers, strings) need a custom
    // serializer for XML persistence.
    @get:OptionTag(converter = MapConverter::class)
    var body by map<String, Any>()

    init {
        headers.putAll(CustomServiceChatCompletionTemplate.OPENAI.headers)
        body.putAll(CustomServiceChatCompletionTemplate.OPENAI.body)
    }
}
/**
 * Persisted code-completion (FIM) endpoint settings, pre-populated with the
 * OPENAI preset's defaults.
 */
class CustomServiceCodeCompletionSettingsState : BaseState() {
    var codeCompletionsEnabled by property(true)
    var infillTemplate by enum(InfillPromptTemplate.OPENAI)
    var url by string(CustomServiceCodeCompletionTemplate.OPENAI.url)
    var headers by map<String, String>()

    // Heterogeneous values (booleans, numbers, strings) need a custom
    // serializer for XML persistence.
    @get:OptionTag(converter = MapConverter::class)
    var body by map<String, Any>()

    init {
        headers.putAll(CustomServiceCodeCompletionTemplate.OPENAI.headers)
        body.putAll(CustomServiceCodeCompletionTemplate.OPENAI.body)
    }
}

View file

@ -0,0 +1,69 @@
package ee.carlrobert.codegpt.settings.service.custom
/**
 * Known OpenAI-compatible providers selectable as presets in the custom
 * service settings. Each entry pairs a chat-completion template with an
 * optional code-completion template (null when the provider offers none,
 * which disables the code-completions tab in the UI).
 */
enum class CustomServiceTemplate(
    val providerName: String,
    val docsUrl: String,
    val chatCompletionTemplate: CustomServiceChatCompletionTemplate,
    val codeCompletionTemplate: CustomServiceCodeCompletionTemplate? = null
) {
    ANYSCALE(
        "Anyscale",
        "https://docs.endpoints.anyscale.com/",
        CustomServiceChatCompletionTemplate.ANYSCALE,
        CustomServiceCodeCompletionTemplate.ANYSCALE,
    ),
    AZURE(
        "Azure OpenAI",
        "https://learn.microsoft.com/en-us/azure/ai-services/openai/reference",
        CustomServiceChatCompletionTemplate.AZURE,
        CustomServiceCodeCompletionTemplate.AZURE
    ),
    DEEP_INFRA(
        "DeepInfra",
        "https://deepinfra.com/docs/advanced/openai_api",
        CustomServiceChatCompletionTemplate.DEEP_INFRA,
        CustomServiceCodeCompletionTemplate.DEEP_INFRA
    ),
    FIREWORKS(
        "Fireworks",
        "https://readme.fireworks.ai/reference/createchatcompletion",
        CustomServiceChatCompletionTemplate.FIREWORKS,
        CustomServiceCodeCompletionTemplate.FIREWORKS
    ),
    GROQ(
        "Groq",
        "https://docs.api.groq.com/md/openai.oas.html",
        CustomServiceChatCompletionTemplate.GROQ
    ),
    OPENAI(
        "OpenAI",
        "https://platform.openai.com/docs/api-reference/chat",
        CustomServiceChatCompletionTemplate.OPENAI,
        CustomServiceCodeCompletionTemplate.OPENAI
    ),
    PERPLEXITY(
        "Perplexity AI",
        "https://docs.perplexity.ai/reference/post_chat_completions",
        CustomServiceChatCompletionTemplate.PERPLEXITY
    ),
    TOGETHER(
        "Together AI",
        "https://docs.together.ai/docs/openai-api-compatibility",
        CustomServiceChatCompletionTemplate.TOGETHER,
        CustomServiceCodeCompletionTemplate.TOGETHER
    ),
    OLLAMA(
        "Ollama",
        "https://github.com/ollama/ollama/blob/main/docs/openai.md",
        CustomServiceChatCompletionTemplate.OLLAMA
    ),
    LLAMA_CPP(
        "LLaMA C/C++",
        "https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md",
        CustomServiceChatCompletionTemplate.LLAMA_CPP
    );

    // Shown in the preset combo box.
    override fun toString(): String {
        return providerName
    }
}

View file

@ -0,0 +1,37 @@
package ee.carlrobert.codegpt.util
import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.project.Project
import com.intellij.openapi.project.ProjectManager
import com.intellij.openapi.wm.IdeFocusManager
object ApplicationUtil {

    /** True when the IDE application exists and is running in unit-test mode. */
    @JvmStatic
    fun isUnitTestingMode(): Boolean {
        val application = ApplicationManager.getApplication()
        return application != null && application.isUnitTestMode
    }

    /**
     * Best-effort lookup of the "current" project: prefers the project of the
     * last focused IDE frame, then falls back to scanning open projects.
     * Returns null when no valid (open, non-default) project exists.
     */
    @JvmStatic
    fun findCurrentProject(): Project? {
        val focusedProject = IdeFocusManager.getGlobalInstance().lastFocusedFrame?.project
        return if (isValidProject(focusedProject)) {
            focusedProject
        } else {
            findProjectFromOpenProjects()
        }
    }

    private fun findProjectFromOpenProjects(): Project? =
        ProjectManager.getInstance().openProjects.firstOrNull { isValidProject(it) }

    private fun isValidProject(project: Project?): Boolean =
        project != null && !project.isDisposed && !project.isDefault
}

View file

@ -0,0 +1,30 @@
package ee.carlrobert.codegpt.util
import com.fasterxml.jackson.core.JsonProcessingException
import com.fasterxml.jackson.core.type.TypeReference
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.datatype.jdk8.Jdk8Module
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule
import com.intellij.util.xmlb.Converter
/**
 * Base XML-serialization [Converter] that persists an arbitrary value of type [T]
 * as a JSON string via Jackson. Subclasses only supply the [TypeReference].
 */
abstract class BaseConverter<T> protected constructor(
    private val typeReference: TypeReference<T>
) : Converter<T>() {

    // Mapper with JDK8 (Optional) and java.time support; converters are long-lived,
    // so building the mapper once per instance is fine.
    private val objectMapper: ObjectMapper = ObjectMapper()
        .registerModule(Jdk8Module())
        .registerModule(JavaTimeModule())

    /**
     * Deserializes the persisted JSON [value] into [T].
     *
     * @throws RuntimeException when the stored string is not valid JSON for [T].
     */
    override fun fromString(value: String): T? {
        try {
            return objectMapper.readValue(value, typeReference)
        } catch (e: JsonProcessingException) {
            // Generic message: this base class backs multiple converters (e.g. MapConverter),
            // not only conversation state, so naming "conversations" here was misleading.
            throw RuntimeException("Unable to deserialize value of type ${typeReference.type}", e)
        }
    }

    /**
     * Serializes [value] to its JSON representation for persistence.
     *
     * @throws RuntimeException when Jackson cannot serialize the value.
     */
    override fun toString(value: T & Any): String? {
        try {
            return objectMapper.writeValueAsString(value)
        } catch (e: JsonProcessingException) {
            throw RuntimeException("Unable to serialize value of type ${typeReference.type}", e)
        }
    }
}

View file

@ -0,0 +1,152 @@
package ee.carlrobert.codegpt.util
import com.intellij.codeInsight.daemon.DaemonCodeAnalyzer
import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.application.PathManager
import com.intellij.openapi.command.WriteCommandAction
import com.intellij.openapi.editor.Document
import com.intellij.openapi.editor.Editor
import com.intellij.openapi.editor.EditorFactory
import com.intellij.openapi.editor.EditorKind
import com.intellij.openapi.fileEditor.FileDocumentManager
import com.intellij.openapi.fileEditor.FileEditorManager
import com.intellij.openapi.fileEditor.TextEditor
import com.intellij.openapi.fileEditor.impl.FileEditorManagerImpl
import com.intellij.openapi.project.Project
import com.intellij.psi.PsiDocumentManager
import com.intellij.psi.codeStyle.CodeStyleManager
import com.intellij.testFramework.LightVirtualFile
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
import java.time.LocalDateTime
import java.time.format.DateTimeFormatter
/**
 * Editor-related helpers: creating temporary preview editors, replacing document
 * content, inspecting selections, and reformatting ranges.
 */
object EditorUtil {
/**
 * Creates a main-kind editor over an in-memory [LightVirtualFile] named after the
 * current timestamp plus [fileExtension], so generated code gets language-appropriate
 * highlighting. Daemon highlighting (inspections) is disabled for the backing document.
 */
@JvmStatic
fun createEditor(project: Project, fileExtension: String, code: String): Editor {
val timestamp = DateTimeFormatter.ofPattern("yyyyMMddHHmmss").format(LocalDateTime.now())
val fileName = "temp_$timestamp$fileExtension"
val lightVirtualFile = LightVirtualFile(
String.format("%s/%s", PathManager.getTempPath(), fileName),
code
)
// Reuse the document the platform may already have associated with the virtual file;
// otherwise create a fresh one from the code text.
val existingDocument = FileDocumentManager.getInstance().getDocument(lightVirtualFile)
val document = existingDocument ?: EditorFactory.getInstance().createDocument(code)
disableHighlighting(project, document)
return EditorFactory.getInstance().createEditor(
document,
project,
lightVirtualFile,
true,
EditorKind.MAIN_EDITOR
)
}
/**
 * Replaces the editor document's whole text with [content] inside a write command,
 * then repaints. Runs synchronously (invokeAndWait) in unit-test mode so assertions
 * see the update; otherwise asynchronously via invokeLater on the EDT.
 */
@JvmStatic
fun updateEditorDocument(editor: Editor, content: String) {
val document = editor.document
val application = ApplicationManager.getApplication()
val updateDocumentRunnable = Runnable {
application.runWriteAction {
WriteCommandAction.runWriteCommandAction(editor.project) {
document.replaceString(0, document.textLength, content)
editor.component.repaint()
editor.component.revalidate()
}
}
}
if (application.isUnitTestMode) {
application.invokeAndWait(updateDocumentRunnable)
} else {
application.invokeLater(updateDocumentRunnable)
}
}
/** True when [editor] is non-null and currently has a text selection. */
@JvmStatic
fun hasSelection(editor: Editor?): Boolean {
return editor?.selectionModel?.hasSelection() == true
}
/** The currently selected text editor of [project], or null when none is open. */
@JvmStatic
fun getSelectedEditor(project: Project): Editor? {
val editorManager = FileEditorManager.getInstance(project)
return editorManager?.selectedTextEditor
}
/** Selected text of the active editor, or null when there is no editor or no selection. */
@JvmStatic
fun getSelectedEditorSelectedText(project: Project): String? {
val selectedEditor = getSelectedEditor(project)
return selectedEditor?.selectionModel?.selectedText
}
/**
 * True when [editor] is the editor the user currently has selected in its project.
 * Uses the impl-specific lock-free lookup when available, falling back to comparing
 * against the selected [TextEditor].
 */
@JvmStatic
fun isSelectedEditor(editor: Editor): Boolean {
val project = editor.project
if (project != null && !project.isDisposed) {
val editorManager = FileEditorManager.getInstance(project) ?: return false
if (editorManager is FileEditorManagerImpl) {
// true = do not lock; safe to call from threads other than the EDT.
return editor == editorManager.getSelectedTextEditor(true)
}
val current = editorManager.selectedEditor
return (current is TextEditor) && editor == current.editor
}
return false
}
/** True when the project's active editor has a text selection. */
@JvmStatic
fun isMainEditorTextSelected(project: Project): Boolean {
return hasSelection(getSelectedEditor(project))
}
/**
 * Replaces the current selection of the active editor with [text] inside a write
 * command, optionally reformatting the replaced range (per configuration), then
 * refocuses the editor and clears the selection. No-op when no editor is selected.
 */
@JvmStatic
fun replaceMainEditorSelection(project: Project, text: String) {
val application = ApplicationManager.getApplication()
application.invokeLater {
application.runWriteAction {
WriteCommandAction.runWriteCommandAction(project) {
val editor = getSelectedEditor(project)
editor?.let {
val selectionModel = editor.selectionModel
val startOffset = selectionModel.selectionStart
val endOffset = selectionModel.selectionEnd
val document = editor.document
document.replaceString(startOffset, endOffset, text)
// NOTE(review): offsets refer to the pre-replacement selection; reformat uses
// the same range, which may not cover all of `text` if it is longer — confirm intended.
if (ConfigurationSettings.getCurrentState().isAutoFormattingEnabled) {
reformatDocument(project, document, startOffset, endOffset)
}
editor.contentComponent.requestFocus()
selectionModel.removeSelection()
}
}
}
}
}
/**
 * Commits [document] to PSI and reformats the [startOffset]..[endOffset] range with the
 * project's code-style settings. No-op when no PSI file backs the document.
 */
@JvmStatic
fun reformatDocument(
project: Project,
document: Document,
startOffset: Int,
endOffset: Int
) {
val psiDocumentManager = PsiDocumentManager.getInstance(project)
psiDocumentManager.commitDocument(document)
val psiFile = psiDocumentManager.getPsiFile(document)
psiFile?.let {
CodeStyleManager.getInstance(project).reformatText(psiFile, startOffset, endOffset)
}
}
/** Turns off daemon analysis (inspections/highlighting) for the PSI file backing [document]. */
@JvmStatic
fun disableHighlighting(project: Project, document: Document) {
val psiFile = PsiDocumentManager.getInstance(project).getPsiFile(document)
psiFile?.let {
DaemonCodeAnalyzer.getInstance(project).setHighlightingEnabled(psiFile, false)
}
}
}

View file

@ -0,0 +1,5 @@
package ee.carlrobert.codegpt.util
import com.fasterxml.jackson.core.type.TypeReference
/** XML-serialization converter that persists a `Map<String, Any>` state as a JSON string. */
class MapConverter : BaseConverter<Map<String, Any>>(object : TypeReference<Map<String, Any>>() {})

View file

@ -0,0 +1,42 @@
package ee.carlrobert.codegpt.util
import com.vladsch.flexmark.html.HtmlRenderer
import com.vladsch.flexmark.parser.Parser
import com.vladsch.flexmark.util.data.MutableDataSet
import ee.carlrobert.codegpt.toolwindow.chat.ResponseNodeRenderer
import java.util.regex.Pattern
object MarkdownUtil {

    // Matches fenced code blocks (triple backticks) including their content; (?s) lets
    // '.' cross newlines. Hoisted so the regex is compiled once instead of per call.
    private val CODE_BLOCK_PATTERN = Pattern.compile("(?s)```.*?```")

    /**
     * Splits a given string into a list of strings where each element is either a code block
     * surrounded by triple backticks or a non-code block text. Blank segments are dropped.
     *
     * @param inputMarkdown The input markdown formatted string to be split.
     * @return A list of strings where each element is a code block or a non-code block text from the
     * input string.
     */
    @JvmStatic
    fun splitCodeBlocks(inputMarkdown: String): List<String> {
        val result: MutableList<String> = ArrayList()
        val matcher = CODE_BLOCK_PATTERN.matcher(inputMarkdown)
        var start = 0
        while (matcher.find()) {
            result.add(inputMarkdown.substring(start, matcher.start()))
            result.add(matcher.group())
            start = matcher.end()
        }
        result.add(inputMarkdown.substring(start))
        return result.stream().filter(String::isNotBlank).toList()
    }

    /**
     * Renders markdown [message] to HTML via flexmark, using the plugin's custom
     * response node renderer.
     *
     * @throws IllegalArgumentException when [message] is null (previously an
     * uninformative NPE from `message!!`).
     */
    @JvmStatic
    fun convertMdToHtml(message: String?): String {
        val markdown = requireNotNull(message) { "Cannot convert a null markdown message to HTML" }
        val options = MutableDataSet()
        val document = Parser.builder(options).build().parse(markdown)
        return HtmlRenderer.builder(options)
            .nodeRendererFactory(ResponseNodeRenderer.Factory())
            .build()
            .render(document)
    }
}

View file

@ -0,0 +1,4 @@
package ee.carlrobert.codegpt.util.file
// Deserialization target for the fileExtensionLanguageMappings.json resource: pairs a
// file extension with a mapped value — presumably the language name; verify against the JSON.
// @JvmRecord compiles it as a JVM record for Java interop.
@JvmRecord
data class FileExtensionLanguageDetails(val extension: String, val value: String)

View file

@ -0,0 +1,213 @@
package ee.carlrobert.codegpt.util.file
import com.fasterxml.jackson.core.JsonProcessingException
import com.fasterxml.jackson.core.type.TypeReference
import com.fasterxml.jackson.databind.ObjectMapper
import com.intellij.openapi.diagnostic.Logger
import com.intellij.openapi.editor.Editor
import com.intellij.openapi.fileEditor.FileDocumentManager
import com.intellij.openapi.progress.ProgressIndicator
import com.intellij.openapi.vfs.VirtualFile
import ee.carlrobert.codegpt.CodeGPTPlugin
import java.io.File
import java.io.FileOutputStream
import java.io.IOException
import java.io.Writer
import java.net.URL
import java.nio.ByteBuffer
import java.nio.channels.Channels
import java.nio.charset.StandardCharsets
import java.nio.file.Files
import java.nio.file.Path
import java.nio.file.Paths
import java.nio.file.StandardOpenOption
import java.text.DecimalFormat
import java.util.Objects
import java.util.Optional
import java.util.regex.Pattern
/**
 * File-system and resource helpers: file creation, download-with-progress, extension and
 * language/extension mapping lookups, encoding checks, and human-readable size formatting.
 */
object FileUtil {
private val LOG = Logger.getInstance(FileUtil::class.java)
/**
 * Creates (if needed) [directoryPath] and writes [fileContent] to [fileName] inside it,
 * returning the resulting [File].
 *
 * NOTE(review): StandardOpenOption.CREATE without TRUNCATE_EXISTING means an existing
 * longer file keeps its trailing bytes after the new content — confirm intended.
 *
 * @throws RuntimeException wrapping any [IOException].
 */
@JvmStatic
fun createFile(directoryPath: String, fileName: String?, fileContent: String?): File {
try {
tryCreateDirectory(directoryPath)
return Files.writeString(
Path.of(directoryPath, fileName),
fileContent,
StandardOpenOption.CREATE
).toFile()
} catch (e: IOException) {
throw RuntimeException("Failed to create file", e)
}
}
/**
 * Downloads [url] into the llama-models directory as [fileName], accumulating the byte
 * count into the single-element [bytesRead] array and reporting progress against
 * [fileSize] through [indicator]. Stops early (leaving a partial file) when the
 * indicator is cancelled.
 *
 * @throws IOException on network or disk failure.
 */
@JvmStatic
@Throws(IOException::class)
fun copyFileWithProgress(
fileName: String,
url: URL,
bytesRead: LongArray,
fileSize: Long,
indicator: ProgressIndicator
) {
tryCreateDirectory(CodeGPTPlugin.getLlamaModelsPath())
Channels.newChannel(url.openStream()).use { readableByteChannel ->
FileOutputStream(
CodeGPTPlugin.getLlamaModelsPath() + File.separator + fileName
).use { fileOutputStream ->
// 10 KiB direct buffer: read from the network channel, flip, write to disk, clear.
val buffer = ByteBuffer.allocateDirect(1024 * 10)
while (readableByteChannel.read(buffer) != -1) {
if (indicator.isCanceled) {
readableByteChannel.close()
break
}
buffer.flip()
bytesRead[0] += fileOutputStream.channel.write(buffer).toLong()
buffer.clear()
indicator.fraction = bytesRead[0].toDouble() / fileSize
}
}
}
}
/** Virtual file backing the editor's document, or null when the document has no file. */
@JvmStatic
fun getEditorFile(editor: Editor): VirtualFile? {
return FileDocumentManager.getInstance().getFile(editor.document)
}
// Creates the directory if it does not already exist; wraps failures in RuntimeException.
private fun tryCreateDirectory(directoryPath: String) {
try {
if (!com.intellij.openapi.util.io.FileUtil.exists(directoryPath)) {
if (!com.intellij.openapi.util.io.FileUtil.createDirectory(
Path.of(directoryPath).toFile()
)
) {
throw IOException("Failed to create directory: $directoryPath")
}
}
} catch (e: IOException) {
throw RuntimeException("Failed to create directory", e)
}
}
/**
 * Returns the text after the last '.' in [filename] (regex `[^.]+$`), or "" when
 * [filename] is null or has no match. Note: for a name without any dot this returns
 * the whole name, not "".
 */
@JvmStatic
fun getFileExtension(filename: String?): String {
val pattern = Pattern.compile("[^.]+$")
val matcher = filename?.let { pattern.matcher(it) }
if (matcher?.find() == true) {
return matcher.group()
}
return ""
}
/**
 * Resolves [language] to a (language name -> file extension) entry using two bundled
 * JSON resources: first by language name, then by treating [language] as an extension
 * and following it to its language. Falls back to ("Text" -> ".txt") when nothing
 * matches or the resources cannot be parsed.
 */
@JvmStatic
fun findLanguageExtensionMapping(language: String): Map.Entry<String, String> {
val defaultValue = mapOf("Text" to ".txt").entries.first()
val mapper = ObjectMapper()
val extensionToLanguageMappings: List<FileExtensionLanguageDetails>
val languageToExtensionMappings: List<LanguageFileExtensionDetails>
try {
extensionToLanguageMappings = mapper.readValue(
getResourceContent("/fileExtensionLanguageMappings.json"),
object : TypeReference<List<FileExtensionLanguageDetails>>() {
})
languageToExtensionMappings = mapper.readValue(
getResourceContent("/languageFileExtensionMappings.json"),
object : TypeReference<List<LanguageFileExtensionDetails>>() {
})
} catch (e: JsonProcessingException) {
LOG.error("Unable to extract file extension", e)
return defaultValue
}
// Direct language-name lookup first; otherwise interpret the input as an extension
// and map extension -> language -> extension entry.
return findFirstExtension(languageToExtensionMappings, language)
.or {
extensionToLanguageMappings.stream()
.filter { it.extension.equals(language, ignoreCase = true) }
.findFirst()
.flatMap { findFirstExtension(languageToExtensionMappings, it.value) }
}.orElse(defaultValue)
}
/**
 * Best-effort check that the file decodes as UTF-8: Files.newBufferedReader uses UTF-8
 * and throws on malformed input while the content is streamed to a null writer.
 *
 * NOTE(review): the broad `catch (e: Exception)` also returns false for unrelated
 * failures (missing file, null path NPE), not just encoding errors — confirm intended.
 */
fun isUtf8File(filePath: String?): Boolean {
val path = filePath?.let { Paths.get(it) }
try {
Files.newBufferedReader(path).use { reader ->
val c = reader.read()
if (c >= 0) {
reader.transferTo(Writer.nullWriter())
}
return true
}
} catch (e: Exception) {
return false
}
}
/**
 * Media (MIME) type for a png/jpg/jpeg file name.
 *
 * @throws IllegalArgumentException for any other extension.
 */
@JvmStatic
fun getImageMediaType(fileName: String?): String {
return when (val fileExtension = getFileExtension(fileName)) {
"png" -> "image/png"
"jpg", "jpeg" -> "image/jpeg"
else -> throw IllegalArgumentException("Unsupported image type: $fileExtension")
}
}
/**
 * Reads a classpath resource as a UTF-8 string.
 *
 * @throws RuntimeException on read failure; NullPointerException when the resource
 * does not exist (requireNonNull on the stream).
 */
@JvmStatic
fun getResourceContent(name: String?): String {
try {
Objects.requireNonNull(name?.let { FileUtil::class.java.getResourceAsStream(it) }).use { stream ->
return String(stream.readAllBytes(), StandardCharsets.UTF_8)
}
} catch (e: IOException) {
throw RuntimeException("Unable to read resource", e)
}
}
/** Formats a byte count with binary-1024 steps as e.g. "1.5 MB" (two decimal places max). */
@JvmStatic
fun convertFileSize(fileSizeInBytes: Long): String {
val units = arrayOf("B", "KB", "MB", "GB")
var unitIndex = 0
var fileSize = fileSizeInBytes.toDouble()
while (fileSize >= 1024 && unitIndex < units.size - 1) {
fileSize /= 1024.0
unitIndex++
}
return DecimalFormat("#.##").format(fileSize) + " " + units[unitIndex]
}
/** Abbreviates a count: 1_500_000 -> "1M", 1_500 -> "1K" (integer division truncates). */
@JvmStatic
fun convertLongValue(value: Long): String {
if (value >= 1000000) {
return (value / 1000000).toString() + "M"
}
if (value >= 1000) {
return (value / 1000).toString() + "K"
}
return value.toString()
}
/**
 * First mapping whose name equals [language] (case-insensitive) and that has at least
 * one non-blank extension, returned as (name -> first non-blank extension).
 */
@JvmStatic
fun findFirstExtension(
languageFileExtensionMappings: List<LanguageFileExtensionDetails>,
language: String
): Optional<Map.Entry<String, String>> {
return languageFileExtensionMappings.stream()
.filter { language.equals(it.name, ignoreCase = true)
&& it.extensions != null
&& it.extensions.stream().anyMatch(String::isNotBlank) }
.findFirst()
.map { java.util.Map.entry(it.name,
it.extensions?.stream()?.filter(String::isNotBlank)?.findFirst()?.orElse("") ?: ""
) }
}
}

View file

@ -0,0 +1,4 @@
package ee.carlrobert.codegpt.util.file
// Deserialization target for the languageFileExtensionMappings.json resource: a language
// name, its type descriptor, and its associated file extensions (nullable when absent).
// @JvmRecord compiles it as a JVM record for Java interop.
@JvmRecord
data class LanguageFileExtensionDetails(val name: String, val type: String, val extensions: List<String>?)

View file

@ -9,8 +9,8 @@
<projectListeners>
<listener topic="com.intellij.codeInsight.lookup.LookupManagerListener"
class="ee.carlrobert.codegpt.completions.MethodNameLookupListener"/>
<listener class="ee.carlrobert.codegpt.toolwindow.chat.ChatToolWindowListener"
topic="com.intellij.openapi.wm.ex.ToolWindowManagerListener"/>
<listener topic="com.intellij.openapi.wm.ex.ToolWindowManagerListener"
class="ee.carlrobert.codegpt.toolwindow.chat.ChatToolWindowListener"/>
</projectListeners>
<extensions defaultExtensionNs="com.intellij">
@ -33,7 +33,6 @@
<applicationService serviceImplementation="ee.carlrobert.codegpt.settings.service.azure.AzureSettings"/>
<applicationService serviceImplementation="ee.carlrobert.codegpt.settings.service.anthropic.AnthropicSettings"/>
<applicationService serviceImplementation="ee.carlrobert.codegpt.settings.service.openai.OpenAISettings"/>
<applicationService serviceImplementation="ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings"/>
<applicationService serviceImplementation="ee.carlrobert.codegpt.settings.service.you.YouSettings"/>
<applicationService serviceImplementation="ee.carlrobert.codegpt.settings.service.llama.LlamaSettings"/>
<applicationService serviceImplementation="ee.carlrobert.codegpt.settings.IncludedFilesSettings"/>

View file

@ -2,9 +2,10 @@ project.label=CodeGPT
notification.group.name=CodeGPT notification group
action.generateCommitMessage.title=Generate Message
action.generateCommitMessage.description=Generate commit message
action.generateCommitMessage.serviceWarning=Messages can only be generated with OpenAI or Azure service
action.generateCommitMessage.serviceWarning=Messages can only be generated with OpenAI, Custom OpenAI, or Azure service
action.generateCommitMessage.missingCredentials=Credentials not provided
action.includeFilesInContext.title=Include In Context...
action.includeFileInContext.title=Include File In Context...
action.includeFilesInContext.dialog.title=Include In Context
action.includeFilesInContext.dialog.description=Choose the files that you wish to include in the final prompt
action.includeFilesInContext.dialog.repeatableContext.label=Repeatable context:
@ -62,6 +63,8 @@ settingsConfigurable.service.llama.threads.label=Threads:
settingsConfigurable.service.llama.threads.comment=The number of threads available to execute the model. It is not recommended to specify a number greater than the number of processor cores.
settingsConfigurable.service.llama.additionalParameters.label=Additional parameters:
settingsConfigurable.service.llama.additionalParameters.comment=<html>Additional command-line parameters for the server startup process, separated by commas. See the full <a href="https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md">list of options</a>.<p><i>Example: "--n-gpu-layers, 1, --no-mmap, --mlock"</i></p></html>
settingsConfigurable.service.llama.additionalBuildParameters.label=Additional build parameters:
settingsConfigurable.service.llama.additionalBuildParameters.comment=<html>Additional command-line parameters for the server build process, separated by commas. See the full <a href="https://github.com/ggerganov/llama.cpp/tree/master?tab=readme-ov-file#build">list of build options</a>.<p><i>Example: "LLAMA_CUBLAS=1,CUDA_DOCKER_ARCH=all"</i></p></html>
settingsConfigurable.service.llama.baseHost.label=Base host:
settingsConfigurable.service.llama.baseHost.comment=URL to existing LLama server
settingsConfigurable.service.llama.startServer.label=Start server
@ -113,9 +116,8 @@ settingsConfigurable.service.custom.openai.url.label=URL:
settingsConfigurable.service.custom.openai.linkToDocs=Link to API docs
settingsConfigurable.service.custom.openai.connectionSuccess=Connection successful.
settingsConfigurable.service.custom.openai.connectionFailed=Connection failed.
configurationConfigurable.section.commitMessage.title=Commit Message
configurationConfigurable.section.commitMessage.systemPromptField.label=Prompt:
configurationConfigurable.section.commitMessage.systemPromptField.comment=Custom system prompt used for commit message generation.
configurationConfigurable.section.commitMessage.title=Commit Message Template
configurationConfigurable.section.commitMessage.systemPromptField.label=Prompt template:
configurationConfigurable.section.inlineCompletion.title=Inline Completion
configurationConfigurable.section.inlineCompletion.systemPromptField.label=Prompt:
configurationConfigurable.section.inlineCompletion.systemPromptField.comment=Custom system prompt used for inline code generation (Fill in the Middle (FIM) template).<br/>The {pre}, {suf} and {mid} are replaced depending on the used Model's FIM template.
@ -179,7 +181,7 @@ validation.error.mustBeGreaterThanZero=Value must be greater than 0
checkForUpdatesTask.title=Checking for CodeGPT update...
checkForUpdatesTask.notification.message=An update for CodeGPT is available.
checkForUpdatesTask.notification.installButton=Install update
llamaServerAgent.buildingProject.description=Building llama.cpp...
llamaServerAgent.buildingProject.description=Building server...
llamaServerAgent.serverBootup.description=Booting up server...
notification.compilationError.description=CodeGPT has detected a compilation error. Would you like assistance in resolving it?
notification.compilationError.okLabel=Resolve errors
@ -198,6 +200,7 @@ action.attachImage=Attach Image
action.attachImageDescription=Attach an image
imageFileChooser.title=Select Image
imageAccordion.title=Attached image
shared.chatCompletions=Chat Completions
shared.codeCompletions=Code Completions
codeCompletionsForm.enableFeatureText=Enable code completions
codeCompletionsForm.maxTokensLabel=Max tokens:

View file

@ -3,10 +3,14 @@ package ee.carlrobert.codegpt.completions
import ee.carlrobert.codegpt.completions.llama.PromptTemplate.ALPACA
import ee.carlrobert.codegpt.completions.llama.PromptTemplate.CHAT_ML
import ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA
import ee.carlrobert.codegpt.completions.llama.PromptTemplate.LLAMA_3
import ee.carlrobert.codegpt.completions.llama.PromptTemplate.TORA
import ee.carlrobert.codegpt.conversations.message.Message
import org.assertj.core.api.Assertions.assertThat
import org.junit.Test
import org.junit.jupiter.api.Test
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.NullAndEmptySource
import org.junit.jupiter.params.provider.ValueSource
class PromptTemplateTest {
@ -34,6 +38,72 @@ class PromptTemplateTest {
""".trimIndent())
}
@Test
fun shouldBuildLlama3PromptWithoutHistory() {
val prompt = LLAMA_3.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, listOf())
assertThat(prompt).isEqualTo("""
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
TEST_SYSTEM_PROMPT<|eot_id|><|start_header_id|>user<|end_header_id|>
TEST_USER_PROMPT<|eot_id|><|start_header_id|>assistant<|end_header_id|>""".trimIndent()
)
}
@ParameterizedTest
@NullAndEmptySource
@ValueSource(strings = [" ", "\t", "\n"])
fun shouldBuildLlama3PromptWithoutHistorySkippingBlankSystemPrompt(systemPrompt: String?) {
val prompt = LLAMA_3.buildPrompt(systemPrompt, USER_PROMPT, listOf())
assertThat(prompt).isEqualTo("""
<|begin_of_text|><|start_header_id|>user<|end_header_id|>
TEST_USER_PROMPT<|eot_id|><|start_header_id|>assistant<|end_header_id|>""".trimIndent()
)
}
@Test
fun shouldBuildLlama3PromptWithHistory() {
val prompt = LLAMA_3.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY)
assertThat(prompt).isEqualTo("""
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
TEST_SYSTEM_PROMPT<|eot_id|><|start_header_id|>user<|end_header_id|>
TEST_PREV_PROMPT_1<|eot_id|><|start_header_id|>assistant<|end_header_id|>
TEST_PREV_RESPONSE_1<|eot_id|><|start_header_id|>user<|end_header_id|>
TEST_PREV_PROMPT_2<|eot_id|><|start_header_id|>assistant<|end_header_id|>
TEST_PREV_RESPONSE_2<|eot_id|><|start_header_id|>user<|end_header_id|>
TEST_USER_PROMPT<|eot_id|><|start_header_id|>assistant<|end_header_id|>""".trimIndent())
}
@ParameterizedTest
@NullAndEmptySource
@ValueSource(strings = [" ", "\t", "\n"])
fun shouldBuildLlama3PromptWithHistorySkippingBlankSystemPrompt(systemPrompt: String?) {
val prompt = LLAMA_3.buildPrompt(systemPrompt, USER_PROMPT, HISTORY)
assertThat(prompt).isEqualTo("""
<|begin_of_text|><|start_header_id|>user<|end_header_id|>
TEST_PREV_PROMPT_1<|eot_id|><|start_header_id|>assistant<|end_header_id|>
TEST_PREV_RESPONSE_1<|eot_id|><|start_header_id|>user<|end_header_id|>
TEST_PREV_PROMPT_2<|eot_id|><|start_header_id|>assistant<|end_header_id|>
TEST_PREV_RESPONSE_2<|eot_id|><|start_header_id|>user<|end_header_id|>
TEST_USER_PROMPT<|eot_id|><|start_header_id|>assistant<|end_header_id|>""".trimIndent())
}
@Test
fun shouldBuildAlpacaPromptWithHistory() {
val prompt = ALPACA.buildPrompt(SYSTEM_PROMPT, USER_PROMPT, HISTORY)

View file

@ -0,0 +1,30 @@
package ee.carlrobert.codegpt.settings.configuration
import com.intellij.openapi.components.service
import git4idea.commands.GitCommand
import org.assertj.core.api.Assertions.assertThat
import testsupport.VcsTestCase
import java.time.LocalDate
class CommitMessageTemplateTest : VcsTestCase() {

    /**
     * Verifies that the {BRANCH_NAME} and {DATE_ISO_8601} placeholders of the configured
     * commit-message prompt are substituted with the checked-out branch and today's date.
     */
    fun `test commit message system prompt construction`() {
        git(GitCommand.INIT)
        git(GitCommand.CHECKOUT, listOf("-b", "feature/my-cool-feature"))
        registerRepository()
        service<ConfigurationSettings>().state.commitMessagePrompt =
            "Branch: {BRANCH_NAME}\nDate: {DATE_ISO_8601}"

        val actualPrompt = project.service<CommitMessageTemplate>().getSystemPrompt()

        assertThat(actualPrompt)
            .isEqualTo("Branch: feature/my-cool-feature\nDate: ${LocalDate.now()}")
    }
}

View file

@ -25,6 +25,17 @@ class GeneralSettingsTest : BasePlatformTestCase() {
assertThat(openAISettings.model).isEqualTo("gpt-4")
}
fun testCustomOpenAISettingsSync() {
    // Start from a different provider so sync() actually has to switch the service.
    val settings = GeneralSettings.getInstance()
    settings.state.selectedService = ServiceType.OPENAI
    val conversation = Conversation().apply {
        clientCode = "custom.openai.chat.completion"
    }

    settings.sync(conversation)

    // Syncing a custom-OpenAI conversation must select the CUSTOM_OPENAI service.
    assertThat(settings.state.selectedService).isEqualTo(ServiceType.CUSTOM_OPENAI)
}
fun testAzureSettingsSync() {
val settings = GeneralSettings.getInstance()
val conversation = Conversation()

View file

@ -0,0 +1,67 @@
package ee.carlrobert.codegpt.telemetry.core.service.segment
import com.google.gson.Gson
import com.google.gson.JsonSyntaxException
import com.intellij.util.io.write
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import kotlin.io.path.Path
import kotlin.io.path.createTempFile
import kotlin.io.path.readText
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertNull
// Deliberately malformed JSON used to exercise error paths.
private const val NOT_JSON = "}NOT]:JSON{"
/**
 * Tests for [IdentifyTraitsPersistence]: reading/writing telemetry identify traits
 * from/to a file. Each test redirects the static FILE target to a fresh temp file.
 */
class IdentifyTraitsPersistenceTest {
private val gson = Gson()
// Singleton under test; its cached traits are reset in setUp so tests stay independent.
private val persistence = IdentifyTraitsPersistence.INSTANCE
private val identifyTraits = IdentifyTraits("locale", "timezone", "os", "version", "distribution")
@BeforeEach
fun setUp() {
persistence.identifyTraits = null
IdentifyTraitsPersistence.FILE = createTempFile()
}
@Test
fun `get returns null when file does not exist`() {
// " " is not an existing path, so the read should yield null rather than throw.
IdentifyTraitsPersistence.FILE = Path(" ")
assertNull(persistence.get())
}
@Test
fun `get throws JsonSyntaxException when file contains malformed JSON`() {
IdentifyTraitsPersistence.FILE.write(NOT_JSON)
assertFailsWith<JsonSyntaxException> {
persistence.get()
}
}
@Test
fun `set saves the event to the file overwriting it`() {
// Pre-seed garbage to prove set() replaces, not appends.
IdentifyTraitsPersistence.FILE.write(NOT_JSON)
persistence.set(identifyTraits)
assertEquals(IdentifyTraitsPersistence.FILE.readText(), gson.toJson(identifyTraits))
}
@Test
fun `set saves the event to the file when file does not exist`() {
persistence.set(identifyTraits)
assertEquals(IdentifyTraitsPersistence.FILE.readText(), gson.toJson(identifyTraits))
}
@Test
fun `get returns the deserialized event`() {
IdentifyTraitsPersistence.FILE.write(gson.toJson(identifyTraits))
assertEquals(identifyTraits, persistence.get())
}
@Test
fun `set throws IOException when file cannot be written and returns false`() {
// Resolving a child under a regular file produces an unwritable path.
IdentifyTraitsPersistence.FILE = IdentifyTraitsPersistence.FILE.resolve(" xyz ")
assertEquals(persistence.set(identifyTraits), false)
}
}

View file

@ -0,0 +1,54 @@
package testsupport
import com.intellij.openapi.components.service
import com.intellij.openapi.vcs.ProjectLevelVcsManager
import com.intellij.openapi.vcs.VcsDirectoryMapping
import com.intellij.openapi.vfs.LocalFileSystem
import com.intellij.testFramework.HeavyPlatformTestCase
import git4idea.GitVcs
import git4idea.commands.Git
import git4idea.commands.GitCommand
import git4idea.commands.GitLineHandler
import git4idea.repo.GitRepository
import git4idea.repo.GitRepositoryManager
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.runBlocking
import org.assertj.core.api.Assertions.assertThat
import org.junit.Assert
import java.nio.file.Files
import java.nio.file.Path
/**
 * Base class for tests that need a real Git repository inside a heavy (on-disk)
 * platform test project. Provides a per-test working directory, a `git` command
 * helper, and VCS-root registration.
 */
open class VcsTestCase : HeavyPlatformTestCase() {
// Working directory for git commands; recreated for every test in setUp.
private lateinit var projectDir: Path
@Throws(Exception::class)
override fun setUp() {
super.setUp()
projectDir = tempDir.createDir()
}
/** Runs a git [command] with [parameters] in [projectDir], failing on a non-zero exit. */
fun git(command: GitCommand, parameters: List<String> = emptyList()) {
val checkoutHandler = GitLineHandler(project, projectDir.toFile(), command)
checkoutHandler.addParameters(parameters)
service<Git>().runCommand(checkoutHandler).throwOnError()
}
/**
 * Maps [projectDir] as a Git VCS root, asserts the root was registered, refreshes the
 * VFS for the directory, and resolves the [GitRepository] for it (asserting non-null).
 */
fun registerRepository(): GitRepository =
ProjectLevelVcsManager.getInstance(project).run {
directoryMappings = listOf(VcsDirectoryMapping(projectDir.toString(), GitVcs.NAME))
Files.createDirectories(projectDir)
Assert.assertFalse(
"There are no VCS roots. Active VCSs: $allActiveVcss",
allVcsRoots.isEmpty()
)
val file = LocalFileSystem.getInstance().refreshAndFindFileByNioFile(projectDir)
// Repository lookup may hit the filesystem; run it off the calling thread's context.
runBlocking(Dispatchers.IO) {
val repository = project.service<GitRepositoryManager>().getRepositoryForRoot(file)
assertThat(repository).describedAs("Couldn't find repository for root $projectDir")
.isNotNull()
repository!!
}
}
}