feat: support multi-line code completions

This commit is contained in:
Carl-Robert Linnupuu 2024-11-30 22:54:12 +00:00
parent 9251b12c9b
commit a47ffd62e4
23 changed files with 646 additions and 345 deletions

View file

@ -8,33 +8,33 @@ dependencies {
implementation("io.github.bonede:tree-sitter-elixir:0.2.0")
implementation("io.github.bonede:tree-sitter-dockerfile:0.2.0")
implementation("io.github.bonede:tree-sitter-dart:master-a")
implementation("io.github.bonede:tree-sitter-css:0.21.0")
implementation("io.github.bonede:tree-sitter-cpp:0.22.0a")
implementation("io.github.bonede:tree-sitter-c-sharp:0.20.0a")
implementation("io.github.bonede:tree-sitter-css:0.23.1")
implementation("io.github.bonede:tree-sitter-cpp:0.23.4")
implementation("io.github.bonede:tree-sitter-c-sharp:0.23.1")
implementation("io.github.bonede:tree-sitter-fortran:master-a")
implementation("io.github.bonede:tree-sitter-gitattributes:0.1.6")
implementation("io.github.bonede:tree-sitter-go:0.21.0a")
implementation("io.github.bonede:tree-sitter-go:0.23.3")
implementation("io.github.bonede:tree-sitter-graphql:master-a")
implementation("io.github.bonede:tree-sitter-html:0.20.3")
implementation("io.github.bonede:tree-sitter-javascript:0.21.2")
implementation("io.github.bonede:tree-sitter-json:0.21.0a")
implementation("io.github.bonede:tree-sitter-kotlin:0.3.6")
implementation("io.github.bonede:tree-sitter-html:0.23.2")
implementation("io.github.bonede:tree-sitter-javascript:0.23.1")
implementation("io.github.bonede:tree-sitter-json:0.23.0")
implementation("io.github.bonede:tree-sitter-kotlin:0.3.8.1")
implementation("io.github.bonede:tree-sitter-latex:0.3.0a")
implementation("io.github.bonede:tree-sitter-lua:2.1.3a")
implementation("io.github.bonede:tree-sitter-m68k:0.2.7a")
implementation("io.github.bonede:tree-sitter-markdown:0.7.1a")
implementation("io.github.bonede:tree-sitter-objc:main-a")
implementation("io.github.bonede:tree-sitter-perl:1.1.0")
implementation("io.github.bonede:tree-sitter-ruby:0.21.0")
implementation("io.github.bonede:tree-sitter-rust:0.21.2")
implementation("io.github.bonede:tree-sitter-scala:0.21.0a")
implementation("io.github.bonede:tree-sitter-ruby:0.23.1")
implementation("io.github.bonede:tree-sitter-rust:0.23.1")
implementation("io.github.bonede:tree-sitter-scala:0.23.3")
implementation("io.github.bonede:tree-sitter-scss:1.0.0a")
implementation("io.github.bonede:tree-sitter-svelte:0.11.0a")
implementation("io.github.bonede:tree-sitter-swift:0.5.0")
implementation("io.github.bonede:tree-sitter-yaml:0.5.0a")
implementation("io.github.bonede:tree-sitter-java:0.21.0a")
implementation("io.github.bonede:tree-sitter-python:0.21.0a")
implementation("io.github.bonede:tree-sitter-php:0.22.4")
implementation("io.github.bonede:tree-sitter-java:0.23.4")
implementation("io.github.bonede:tree-sitter-python:0.23.4")
implementation("io.github.bonede:tree-sitter-php:0.23.11")
implementation("io.github.bonede:tree-sitter-typescript:0.21.1")
implementation("io.github.bonede:tree-sitter-query:0.3.0")
}

View file

@ -1,44 +1,116 @@
package ee.carlrobert.codegpt.treesitter;
import java.nio.charset.StandardCharsets;
import org.treesitter.TSInputEdit;
import org.treesitter.TSLanguage;
import org.treesitter.TSParser;
import org.treesitter.TSPoint;
import org.treesitter.TSTree;
public class CodeCompletionParser {
protected final TSLanguage language;
private final TSParser parser;
public CodeCompletionParser(TSLanguage language) {
this.language = language;
parser = new TSParser();
parser.setLanguage(language);
}
public String parse(String prefix, String suffix, String output) {
var result = new StringBuilder(output);
String input = prefix + result + suffix;
TSTree currentTree = parser.parseString(null, input);
while (!result.isEmpty()) {
if (containsError(prefix + result + suffix)) {
if (containsError(currentTree)) {
int deletionIndex = prefix.length() + result.length() - 1;
Position pos = getPosition(input, deletionIndex);
int startByte = pos.byteOffset;
int oldEndByte = startByte + getByteLength(result.substring(result.length() - 1));
TSPoint startPoint = pos.point;
TSPoint oldEndPoint = computeOldEndPoint(startPoint, result.charAt(result.length() - 1));
currentTree.edit(
new TSInputEdit(startByte, oldEndByte, startByte, startPoint, oldEndPoint, startPoint));
result.deleteCharAt(result.length() - 1);
if (result.length() > 1 && result.charAt(result.length() - 1) == '{') {
long bracketCount = result.chars().filter(ch -> ch == '{').count();
if (bracketCount == 1) {
var newTree = parser.parseString(currentTree, prefix + result + "}" + suffix);
var treeString = newTree.getRootNode().toString();
if (!treeString.contains("ERROR")) {
return result + "}";
}
}
}
input = prefix + result + suffix;
currentTree = parser.parseString(currentTree, input);
} else {
return result.toString();
}
}
if (output.contains("\n")) {
return parse(prefix, suffix, output.substring(0, output.indexOf("\n")));
var finalResult = output.substring(0, output.indexOf("\n"));
if (finalResult.charAt(finalResult.length() - 1) == '{') {
return finalResult + "}";
}
return finalResult;
}
return output;
}
private boolean containsError(String input) {
var treeString = getTree(input).getRootNode().toString();
private boolean containsError(TSTree tree) {
var treeString = tree.getRootNode().toString();
return treeString.contains("ERROR")
|| treeString.contains("MISSING \"}\"")
|| treeString.contains("MISSING \")\"");
}
private TSTree getTree(String input) {
var parser = new TSParser();
parser.setLanguage(language);
return parser.parseString(null, input);
private Position getPosition(String input, int index) {
int row = 0;
int col = 0;
int byteOffset = 0;
for (int i = 0; i < index; i++) {
char c = input.charAt(i);
int charByteLength = getByteLength(String.valueOf(c));
byteOffset += charByteLength;
if (c == '\n') {
row++;
col = 0;
} else {
col++;
}
}
return new Position(new TSPoint(row, col), byteOffset);
}
private int getByteLength(String str) {
return str.getBytes(StandardCharsets.UTF_8).length;
}
private TSPoint computeOldEndPoint(TSPoint startPoint, char deletedChar) {
int row = startPoint.getRow();
int col = startPoint.getColumn();
if (deletedChar == '\n') {
row++;
col = 0;
} else {
col++;
}
return new TSPoint(row, col);
}
private record Position(TSPoint point, int byteOffset) {
}
}

View file

@ -38,9 +38,7 @@ public class CodeCompletionParserTest {
return 10;
}
}""";
var output = """
prevNumber) {
if() {""";
var output = "prevNumber);";
var parsedResponse = CodeCompletionParserFactory
.getParserForFileExtension("java")

View file

@ -12,9 +12,9 @@ jsoup = "1.17.2"
jtokkit = "1.1.0"
junit = "5.11.0"
kotlin = "2.0.0"
llm-client = "0.8.28"
llm-client = "0.8.29"
okio = "3.9.0"
tree-sitter = "0.22.6a"
tree-sitter = "0.24.4"
[libraries]
analytics = { module = "com.rudderstack.sdk.java.analytics:analytics", version.ref = "analytics" }

View file

@ -25,11 +25,9 @@ public class ConfigurationComponent {
private final JBCheckBox checkForNewScreenshotsCheckBox;
private final JBCheckBox methodNameGenerationCheckBox;
private final JBCheckBox autoFormattingCheckBox;
private final JBCheckBox autocompletionPostProcessingCheckBox;
private final JBCheckBox autocompletionContextAwareCheckBox;
private final JBCheckBox autocompletionGitContextCheckBox;
private final IntegerField maxTokensField;
private final JBTextField temperatureField;
private final CodeCompletionConfigurationForm codeCompletionForm;
public ConfigurationComponent(
Disposable parentDisposable,
@ -72,31 +70,21 @@ public class ConfigurationComponent {
autoFormattingCheckBox = new JBCheckBox(
CodeGPTBundle.get("configurationConfigurable.autoFormatting.label"),
configuration.getAutoFormattingEnabled());
autocompletionPostProcessingCheckBox = new JBCheckBox(
CodeGPTBundle.get("configurationConfigurable.autocompletionPostProcessing.label"),
configuration.getAutocompletionPostProcessingEnabled()
);
autocompletionContextAwareCheckBox = new JBCheckBox(
CodeGPTBundle.get("configurationConfigurable.autocompletionContextAwareCheckBox.label"),
configuration.getAutocompletionContextAwareEnabled()
);
autocompletionGitContextCheckBox = new JBCheckBox(
CodeGPTBundle.get("configurationConfigurable.autocompletionGitContextCheckBox.label"),
configuration.getAutocompletionGitContextEnabled()
);
codeCompletionForm = new CodeCompletionConfigurationForm();
mainPanel = FormBuilder.createFormBuilder()
.addComponent(checkForPluginUpdatesCheckBox)
.addComponent(checkForNewScreenshotsCheckBox)
.addComponent(methodNameGenerationCheckBox)
.addComponent(autoFormattingCheckBox)
.addComponent(autocompletionPostProcessingCheckBox)
.addComponent(autocompletionContextAwareCheckBox)
.addComponent(autocompletionGitContextCheckBox)
.addVerticalGap(4)
.addComponent(new TitledSeparator(
CodeGPTBundle.get("configurationConfigurable.section.assistant.title")))
.addComponent(createAssistantConfigurationForm())
.addComponent(new TitledSeparator(
CodeGPTBundle.get("configurationConfigurable.section.codeCompletion.title")))
.addComponent(codeCompletionForm.createPanel())
.addComponentFillVertically(new JPanel(), 0)
.getPanel();
}
@ -113,9 +101,7 @@ public class ConfigurationComponent {
state.setCheckForNewScreenshots(checkForNewScreenshotsCheckBox.isSelected());
state.setMethodNameGenerationEnabled(methodNameGenerationCheckBox.isSelected());
state.setAutoFormattingEnabled(autoFormattingCheckBox.isSelected());
state.setAutocompletionPostProcessingEnabled(autocompletionPostProcessingCheckBox.isSelected());
state.setAutocompletionContextAwareEnabled(autocompletionContextAwareCheckBox.isSelected());
state.setAutocompletionGitContextEnabled(autocompletionGitContextCheckBox.isSelected());
state.setCodeCompletionSettings(codeCompletionForm.getFormState());
return state;
}
@ -127,13 +113,7 @@ public class ConfigurationComponent {
checkForNewScreenshotsCheckBox.setSelected(configuration.getCheckForNewScreenshots());
methodNameGenerationCheckBox.setSelected(configuration.getMethodNameGenerationEnabled());
autoFormattingCheckBox.setSelected(configuration.getAutoFormattingEnabled());
autocompletionPostProcessingCheckBox.setSelected(
configuration.getAutocompletionPostProcessingEnabled());
autocompletionContextAwareCheckBox.setSelected(
configuration.getAutocompletionContextAwareEnabled());
autocompletionGitContextCheckBox.setSelected(
configuration.getAutocompletionGitContextEnabled()
);
codeCompletionForm.resetForm(configuration.getCodeCompletionSettings());
}
// Formatted keys are not referenced in the messages bundle file

View file

@ -1,5 +1,6 @@
package ee.carlrobert.codegpt.settings.service.llama.form;
import ee.carlrobert.codegpt.codecompletions.CompletionType;
import ee.carlrobert.codegpt.codecompletions.InfillPromptTemplate;
import ee.carlrobert.codegpt.codecompletions.InfillRequest;
@ -17,6 +18,8 @@ public class InfillPromptTemplatePanel extends BasePromptTemplatePanel<InfillPro
@Override
protected String buildPromptDescription(InfillPromptTemplate template) {
return template.buildPrompt(new InfillRequest.Builder("PREFIX", "SUFFIX", 0).build());
return template.buildPrompt(new InfillRequest
.Builder("PREFIX", "SUFFIX", 0, CompletionType.MULTI_LINE)
.build());
}
}

View file

@ -1,21 +1,25 @@
package ee.carlrobert.codegpt.codecompletions
import com.intellij.codeInsight.inline.completion.InlineCompletionRequest
import com.intellij.notification.NotificationType
import com.intellij.openapi.application.runReadAction
import com.intellij.openapi.application.runInEdt
import com.intellij.openapi.application.runWriteAction
import com.intellij.openapi.components.service
import com.intellij.openapi.diagnostic.thisLogger
import com.intellij.openapi.editor.Editor
import com.intellij.openapi.util.TextRange
import ee.carlrobert.codegpt.CodeGPTKeys.IS_FETCHING_COMPLETION
import ee.carlrobert.codegpt.CodeGPTKeys.REMAINING_EDITOR_COMPLETION
import ee.carlrobert.codegpt.codecompletions.CompletionUtil.formatCompletion
import ee.carlrobert.codegpt.settings.GeneralSettings
import ee.carlrobert.codegpt.settings.service.ServiceType
import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTServiceSettings
import ee.carlrobert.codegpt.ui.OverlayUtil.showNotification
import ee.carlrobert.codegpt.util.StringUtil
import ee.carlrobert.codegpt.util.EditorUtil.adjustWhitespaces
import ee.carlrobert.llm.client.openai.completion.ErrorDetails
import ee.carlrobert.llm.completion.CompletionEventListener
import okhttp3.sse.EventSource
import kotlin.math.min
abstract class CodeCompletionEventListener(
private val editor: Editor
@ -25,51 +29,19 @@ abstract class CodeCompletionEventListener(
private val logger = thisLogger()
}
private var isFirstLine = true
private val currentLineBuffer = StringBuilder()
private val incomingTextBuffer = StringBuilder()
open fun onLineReceived(completionLine: String) {}
abstract fun handleCompleted(messageBuilder: StringBuilder)
override fun onOpen() {
setLoading(true)
REMAINING_EDITOR_COMPLETION.set(editor, "")
}
override fun onMessage(message: String, eventSource: EventSource) {
incomingTextBuffer.append(message)
while (incomingTextBuffer.contains("\n")) {
val lineEndIndex = incomingTextBuffer.indexOf("\n")
val line = incomingTextBuffer.substring(0, lineEndIndex) + '\n'
processCompletionLine(line)
incomingTextBuffer.delete(0, lineEndIndex + 1)
}
}
private fun processCompletionLine(line: String) {
currentLineBuffer.append(line)
if (currentLineBuffer.trim().isNotEmpty()) {
val completionText = if (isFirstLine) {
line.adjustWhitespaces().also {
isFirstLine = false
onLineReceived(it)
}
} else {
currentLineBuffer.toString()
}
appendRemainingCompletion(completionText)
currentLineBuffer.clear()
}
}
override fun onComplete(messageBuilder: StringBuilder) {
setLoading(false)
handleCompleted(messageBuilder)
}
override fun onCancelled(messageBuilder: StringBuilder) {
setLoading(false)
handleCompleted(messageBuilder)
}
@ -88,45 +60,134 @@ abstract class CodeCompletionEventListener(
setLoading(false)
}
private fun String.adjustWhitespaces(): String {
val adjustedLine = runReadAction {
val lineNumber = editor.document.getLineNumber(editor.caretModel.offset)
val editorLine = editor.document.getText(
TextRange(
editor.document.getLineStartOffset(lineNumber),
editor.document.getLineEndOffset(lineNumber)
)
)
private fun setLoading(loading: Boolean) {
IS_FETCHING_COMPLETION.set(editor, loading)
editor.project?.messageBus
?.syncPublisher(CodeCompletionProgressNotifier.CODE_COMPLETION_PROGRESS_TOPIC)
?.loading(loading)
}
}
StringUtil.adjustWhitespace(this, editorLine)
class CodeCompletionMultiLineEventListener(
private val request: InlineCompletionRequest,
private val onCompletionReceived: (String) -> Unit
) : CodeCompletionEventListener(request.editor) {
override fun handleCompleted(messageBuilder: StringBuilder) {
runInEdt {
onCompletionReceived(runWriteAction {
messageBuilder.toString().formatCompletion(request)
})
}
}
}
return if (adjustedLine.length != this.length) adjustedLine else this
class CodeCompletionSingleLineEventListener(
private val editor: Editor,
private val infillRequest: InfillRequest,
private val onSend: (element: CodeCompletionTextElement) -> Unit,
) : CodeCompletionEventListener(editor) {
private var isFirstLine = true
private val currentLineBuffer = StringBuilder()
private val incomingTextBuffer = StringBuilder()
override fun onMessage(message: String, eventSource: EventSource) {
incomingTextBuffer.append(message)
while (incomingTextBuffer.contains("\n")) {
val lineEndIndex = incomingTextBuffer.indexOf("\n")
val line = incomingTextBuffer.substring(0, lineEndIndex) + '\n'
processCompletionLine(line)
incomingTextBuffer.delete(0, lineEndIndex + 1)
}
}
private fun handleCompleted(messageBuilder: StringBuilder) {
setLoading(false)
override fun handleCompleted(messageBuilder: StringBuilder) {
if (incomingTextBuffer.isNotEmpty()) {
appendRemainingCompletion(incomingTextBuffer.toString())
}
if (isFirstLine) {
val completionLine = messageBuilder.toString().adjustWhitespaces()
val completionLine = messageBuilder.toString().adjustWhitespaces(editor)
REMAINING_EDITOR_COMPLETION.set(editor, completionLine)
onLineReceived(completionLine)
}
}
private fun processCompletionLine(line: String) {
currentLineBuffer.append(line)
if (currentLineBuffer.trim().isNotEmpty()) {
val completionText = if (isFirstLine) {
line.adjustWhitespaces(editor).also {
isFirstLine = false
onLineReceived(it)
}
} else {
currentLineBuffer.toString()
}
appendRemainingCompletion(completionText)
currentLineBuffer.clear()
}
}
private fun onLineReceived(completionLine: String) {
runInEdt {
var editorLineSuffix = editor.getLineSuffixAfterCaret()
if (editorLineSuffix.isBlank()) {
onSend(
CodeCompletionTextElement(
completionLine,
infillRequest.caretOffset,
TextRange.from(infillRequest.caretOffset, completionLine.length),
)
)
} else {
var caretShift = 0
// TODO: Handle other scenarios
val processedCompletion =
if (completionLine.startsWith(editorLineSuffix.first())) {
caretShift++
editorLineSuffix = editorLineSuffix.substring(1)
completionLine.substring(1)
} else {
completionLine
}
val completionWithRemovedSuffix =
processedCompletion.removeSuffix(editorLineSuffix)
onSend(
CodeCompletionTextElement(
completionWithRemovedSuffix,
infillRequest.caretOffset + caretShift,
TextRange.from(
infillRequest.caretOffset + caretShift,
completionWithRemovedSuffix.length
),
caretShift,
completionLine
)
)
}
}
}
private fun appendRemainingCompletion(text: String) {
val previousRemainingText = REMAINING_EDITOR_COMPLETION.get(editor) ?: ""
REMAINING_EDITOR_COMPLETION.set(editor, previousRemainingText + text)
}
private fun setLoading(loading: Boolean) {
IS_FETCHING_COMPLETION.set(editor, loading)
editor.project?.messageBus
?.syncPublisher(CodeCompletionProgressNotifier.CODE_COMPLETION_PROGRESS_TOPIC)
?.loading(loading)
private fun Editor.getLineSuffixAfterCaret(): String {
val lineEndOffset = document.getLineEndOffset(document.getLineNumber(caretModel.offset))
return document.getText(
TextRange(
caretModel.offset,
min(lineEndOffset + 1, document.textLength)
)
)
}
}

View file

@ -19,6 +19,6 @@ class CodeCompletionProviderPresentation : InlineCompletionProviderPresentation
} else {
"CodeGPT"
}
return JBLabel(text, Icons.Sparkle, SwingConstants.LEADING).withFont(JBFont.small())
return JBLabel(text, Icons.DefaultSmall, SwingConstants.LEADING).withFont(JBFont.small())
}
}

View file

@ -3,11 +3,10 @@ package ee.carlrobert.codegpt.codecompletions
import com.fasterxml.jackson.core.JsonProcessingException
import com.fasterxml.jackson.databind.ObjectMapper
import com.intellij.openapi.components.service
import ee.carlrobert.codegpt.EncodingManager
import ee.carlrobert.codegpt.completions.llama.LlamaModel
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey
import ee.carlrobert.codegpt.credentials.CredentialsStore.getCredential
import ee.carlrobert.codegpt.settings.configuration.Placeholder.*
import ee.carlrobert.codegpt.settings.Placeholder.*
import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTServiceSettings
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings
@ -25,30 +24,29 @@ import java.nio.charset.StandardCharsets
object CodeCompletionRequestFactory {
private const val MAX_TOKENS = 128
private const val MAX_TOKENS = 80
@JvmStatic
fun buildCodeGPTRequest(details: InfillRequest): CodeCompletionRequest {
val settings = service<CodeGPTServiceSettings>().state.codeCompletionSettings
return CodeCompletionRequest.Builder()
.setModel(settings.model)
.setModel(service<CodeGPTServiceSettings>().state.codeCompletionSettings.model)
.setPrefix(details.prefix)
.setSuffix(details.suffix)
.setFileExtension(details.fileDetails?.fileExtension)
.setFileContent(details.fileDetails?.fileContent)
.setStagedDiff(details.vcsDetails?.stagedDiff)
.setUnstagedDiff(details.vcsDetails?.unstagedDiff)
.setStop(details.stopTokens.ifEmpty { null })
.build()
}
@JvmStatic
fun buildOpenAIRequest(details: InfillRequest): OpenAITextCompletionRequest {
val (prefix, suffix) = getCompletionContext(details)
return OpenAITextCompletionRequest.Builder(prefix)
.setSuffix(suffix)
return OpenAITextCompletionRequest.Builder(details.prefix)
.setSuffix(details.suffix)
.setStream(true)
.setMaxTokens(MAX_TOKENS)
.setTemperature(0.4)
.setTemperature(0.0)
.setPresencePenalty(0.0)
.setStop(details.stopTokens.ifEmpty { null })
.build()
}
@ -104,16 +102,26 @@ object CodeCompletionRequestFactory {
val settings = LlamaSettings.getCurrentState()
val promptTemplate = getLlamaInfillPromptTemplate(settings)
val prompt = promptTemplate.buildPrompt(details)
val stopTokens = buildList {
if (promptTemplate.stopTokens != null) addAll(promptTemplate.stopTokens)
if (details.stopTokens.isNotEmpty()) addAll(details.stopTokens)
}.ifEmpty { null }
return LlamaCompletionRequest.Builder(prompt)
.setN_predict(MAX_TOKENS)
.setStream(true)
.setTemperature(0.4)
.setStop(promptTemplate.stopTokens)
.setTemperature(0.0)
.setStop(stopTokens)
.build()
}
fun buildOllamaRequest(details: InfillRequest): OllamaCompletionRequest {
val settings = service<OllamaSettings>().state
val stopTokens = buildList {
if (settings.fimTemplate.stopTokens != null) addAll(settings.fimTemplate.stopTokens!!)
if (details.stopTokens.isNotEmpty()) addAll(details.stopTokens)
}.ifEmpty { null }
return OllamaCompletionRequest.Builder(
settings.model,
settings.fimTemplate.buildPrompt(details)
@ -121,7 +129,7 @@ object CodeCompletionRequestFactory {
.setStream(true)
.setOptions(
OllamaParameters.Builder()
.stop(settings.fimTemplate.stopTokens)
.stop(stopTokens)
.numPredict(MAX_TOKENS)
.temperature(0.4)
.build()
@ -147,41 +155,15 @@ object CodeCompletionRequestFactory {
): Any {
if (value !is String) return value
val (prefix, suffix) = getCompletionContext(details)
return when (value) {
FIM_PROMPT.code -> template.buildPrompt(details)
PREFIX.code -> prefix
SUFFIX.code -> suffix
PREFIX.code -> details.prefix
SUFFIX.code -> details.suffix
else -> {
return value.takeIf { it.contains(PREFIX.code) || it.contains(SUFFIX.code) }
?.replace(PREFIX.code, prefix)
?.replace(SUFFIX.code, suffix) ?: value
?.replace(PREFIX.code, details.prefix)
?.replace(SUFFIX.code, details.suffix) ?: value
}
}
}
private fun getCompletionContext(request: InfillRequest): Pair<String, String> {
val encodingManager = EncodingManager.getInstance()
val truncatedPrefix = encodingManager.truncateText(request.prefix, 128, false)
val truncatedSuffix = encodingManager.truncateText(request.suffix, 128, true)
val vcsDetails = request.vcsDetails ?: return truncatedPrefix to truncatedSuffix
val stagedDiff = if (vcsDetails.stagedDiff != null)
encodingManager.truncateText(vcsDetails.stagedDiff, 200, true)
else
""
val unstagedDiff = if (vcsDetails.unstagedDiff != null)
encodingManager.truncateText(vcsDetails.unstagedDiff, 200, true)
else
""
val prompt: String = if (vcsDetails.stagedDiff != null)
"""
${"/*\n${stagedDiff + unstagedDiff}\n\n*/"}
$truncatedPrefix
""".trimIndent()
else
truncatedPrefix
return prompt to truncatedSuffix
}
}
}

View file

@ -53,20 +53,20 @@ class CodeCompletionService {
}
fun getCodeCompletionAsync(
requestDetails: InfillRequest,
infillRequest: InfillRequest,
eventListener: CompletionEventListener<String>
): EventSource =
when (val selectedService = GeneralSettings.getSelectedService()) {
CODEGPT -> CompletionClientProvider.getCodeGPTClient()
.getCodeCompletionAsync(buildCodeGPTRequest(requestDetails), eventListener)
.getCodeCompletionAsync(buildCodeGPTRequest(infillRequest), eventListener)
OPENAI -> CompletionClientProvider.getOpenAIClient()
.getCompletionAsync(buildOpenAIRequest(requestDetails), eventListener)
.getCompletionAsync(buildOpenAIRequest(infillRequest), eventListener)
CUSTOM_OPENAI -> createFactory(
CompletionClientProvider.getDefaultClientBuilder().build()
).newEventSource(
buildCustomRequest(requestDetails),
buildCustomRequest(infillRequest),
if (service<CustomServiceSettings>().state.codeCompletionSettings.parseResponseAsChatCompletions) {
OpenAIChatCompletionEventSourceListener(eventListener)
} else {
@ -75,10 +75,10 @@ class CodeCompletionService {
)
OLLAMA -> CompletionClientProvider.getOllamaClient()
.getCompletionAsync(buildOllamaRequest(requestDetails), eventListener)
.getCompletionAsync(buildOllamaRequest(infillRequest), eventListener)
LLAMA_CPP -> CompletionClientProvider.getLlamaClient()
.getChatCompletionAsync(buildLlamaRequest(requestDetails), eventListener)
.getChatCompletionAsync(buildLlamaRequest(infillRequest), eventListener)
else -> throw IllegalArgumentException("Code completion not supported for ${selectedService.name}")
}

View file

@ -0,0 +1,100 @@
package ee.carlrobert.codegpt.codecompletions
import ai.grazie.nlp.utils.takeLastWhitespaces
import ai.grazie.nlp.utils.takeWhitespaces
import com.intellij.codeInsight.inline.completion.InlineCompletionRequest
import com.intellij.openapi.components.service
import com.intellij.openapi.diagnostic.thisLogger
import com.intellij.openapi.editor.Document
import com.intellij.openapi.editor.Editor
import com.intellij.openapi.fileEditor.FileDocumentManager
import com.intellij.openapi.util.TextRange
import com.intellij.psi.PsiFile
import com.intellij.psi.PsiFileFactory
import com.intellij.psi.codeStyle.CodeStyleManager
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
import ee.carlrobert.codegpt.treesitter.CodeCompletionParserFactory
import kotlin.math.min
object CompletionUtil {
val logger = thisLogger()
fun String.formatCompletion(request: InlineCompletionRequest): String {
try {
val editor = request.editor
val project = editor.project ?: return this
val document = request.document
val caretOffset = editor.caretModel.offset
val textBeforeCompletion = document.text.substring(0, caretOffset)
val textAfterCompletion = document.text.substring(caretOffset)
val adjustedText = if (
takeWhitespaces().isNotEmpty()
&& textBeforeCompletion.takeLastWhitespaces().isNotEmpty()
) removePrefix(takeWhitespaces())
else this
if (adjustedText.lines().size == 1) {
return adjustedText
}
val originalFile = service<FileDocumentManager>().getFile(document) ?: return ""
val tempFile = project.service<PsiFileFactory>().createFileFromText(
"temp.${originalFile.extension}",
originalFile.fileType,
buildString {
append(textBeforeCompletion)
append(adjustedText)
append(textAfterCompletion)
}
)
project.service<CodeStyleManager>()
.adjustLineIndent(tempFile, TextRange.from(caretOffset, adjustedText.length))
val formattedCompletion =
getFormattedCompletion(adjustedText, tempFile, document, editor)
return if (service<ConfigurationSettings>().state.codeCompletionSettings.treeSitterProcessingEnabled) {
CodeCompletionParserFactory.getParserForFileExtension(originalFile.extension)
.parse(textBeforeCompletion, textAfterCompletion, formattedCompletion)
.trimEnd()
} else {
formattedCompletion
}
} catch (e: Exception) {
logger.error("Failed to format completion output", e)
return this
}
}
private fun getFormattedCompletion(
completionText: String,
tempFile: PsiFile,
document: Document,
editor: Editor,
): String {
val formattedText = StringBuilder()
val tempFileDocument =
FileDocumentManager.getInstance().getDocument(tempFile.virtualFile)
?: return completionText
val currentCaretLine = editor.caretModel.logicalPosition.line
var linePosition = currentCaretLine
for (i in completionText.lines().indices) {
val minPosition = min(linePosition, tempFileDocument.lineCount - 1)
val range = TextRange(
tempFileDocument.getLineStartOffset(minPosition),
tempFileDocument.getLineEndOffset(minPosition)
)
formattedText.append(tempFileDocument.getText(range)).append("\n")
linePosition++
}
val prefixToRemove = document.getText(
TextRange(document.getLineStartOffset(currentCaretLine), editor.caretModel.offset)
)
return formattedText.removePrefix(prefixToRemove).trimEnd().toString()
}
}

View file

@ -2,15 +2,16 @@ package ee.carlrobert.codegpt.codecompletions
import com.intellij.codeInsight.inline.completion.*
import com.intellij.codeInsight.inline.completion.elements.InlineCompletionElement
import com.intellij.codeInsight.inline.completion.elements.InlineCompletionGrayTextElement
import com.intellij.codeInsight.lookup.LookupManager
import com.intellij.openapi.application.runInEdt
import com.intellij.openapi.components.service
import com.intellij.openapi.diagnostic.thisLogger
import com.intellij.openapi.editor.Editor
import com.intellij.openapi.project.Project
import com.intellij.openapi.util.TextRange
import ee.carlrobert.codegpt.CodeGPTKeys.IS_FETCHING_COMPLETION
import ee.carlrobert.codegpt.CodeGPTKeys.REMAINING_EDITOR_COMPLETION
import ee.carlrobert.codegpt.settings.GeneralSettings
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
import ee.carlrobert.codegpt.settings.service.ServiceType
import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTServiceSettings
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings
@ -25,7 +26,6 @@ import kotlinx.coroutines.flow.emptyFlow
import kotlinx.coroutines.launch
import okhttp3.sse.EventSource
import java.util.concurrent.atomic.AtomicReference
import kotlin.math.min
import kotlin.time.Duration
import kotlin.time.DurationUnit
import kotlin.time.toDuration
@ -48,37 +48,77 @@ class DebouncedCodeCompletionProvider : DebouncedInlineCompletionProvider() {
get() = CodeCompletionProviderPresentation()
override suspend fun getSuggestionDebounced(request: InlineCompletionRequest): InlineCompletionSuggestion {
return if (service<ConfigurationSettings>().state.codeCompletionSettings.multiLineEnabled) {
getMultiLineSuggestionDebounced(request)
} else {
getSingleLineSuggestionDebounced(request)
}
}
private fun getSingleLineSuggestionDebounced(request: InlineCompletionRequest): InlineCompletionSuggestion {
val editor = request.editor
val remainingCompletion = REMAINING_EDITOR_COMPLETION.get(editor) ?: ""
if (request.event is InlineCompletionEvent.DirectCall && remainingCompletion.isNotEmpty()) {
if (request.event is InlineCompletionEvent.DirectCall && remainingCompletion.isNotEmpty()
) {
return sendNextSuggestion(remainingCompletion.extractUntilNewline(), request)
}
val project = editor.project
return getSuggestionDebounced(
request,
CompletionType.SINGLE_LINE
) { project, infillRequest ->
project.service<CodeCompletionService>()
.getCodeCompletionAsync(
infillRequest,
CodeCompletionSingleLineEventListener(request.editor, infillRequest) {
trySend(it)
}
)
}
}
private fun getMultiLineSuggestionDebounced(request: InlineCompletionRequest): InlineCompletionSuggestion {
return getSuggestionDebounced(
request,
CompletionType.MULTI_LINE
) { project, infillRequest ->
project.service<CodeCompletionService>()
.getCodeCompletionAsync(
infillRequest,
CodeCompletionMultiLineEventListener(request) {
trySend(InlineCompletionGrayTextElement(it))
}
)
}
}
private fun getSuggestionDebounced(
request: InlineCompletionRequest,
completionType: CompletionType,
fetchCompletion: ProducerScope<InlineCompletionElement>.(Project, InfillRequest) -> EventSource
): InlineCompletionSuggestion {
val project = request.editor.project
if (project == null) {
logger.error("Could not find project")
return InlineCompletionSuggestion.Default(emptyFlow())
}
if (LookupManager.getActiveLookup(editor) != null) {
if (LookupManager.getActiveLookup(request.editor) != null) {
return InlineCompletionSuggestion.Default(emptyFlow())
}
if (LookupManager.getActiveLookup(request.editor) != null) {
return InlineCompletionSuggestion.Default(emptyFlow())
}
IS_FETCHING_COMPLETION.set(request.editor, true)
request.editor.project?.messageBus
?.syncPublisher(CodeCompletionProgressNotifier.CODE_COMPLETION_PROGRESS_TOPIC)
?.loading(true)
return InlineCompletionSuggestion.Default(channelFlow {
REMAINING_EDITOR_COMPLETION.set(request.editor, "")
IS_FETCHING_COMPLETION.set(request.editor, true)
request.editor.project?.messageBus
?.syncPublisher(CodeCompletionProgressNotifier.CODE_COMPLETION_PROGRESS_TOPIC)
?.loading(true)
val infillRequest = InfillRequestUtil.buildInfillRequest(request)
val call = project.service<CodeCompletionService>()
.getCodeCompletionAsync(
infillRequest,
getEventListener(request.editor, infillRequest)
)
currentCallRef.set(call)
val infillRequest = InfillRequestUtil.buildInfillRequest(request, completionType)
currentCallRef.set(fetchCompletion(project, infillRequest))
awaitClose { currentCallRef.getAndSet(null)?.cancel() }
})
}
@ -115,70 +155,10 @@ class DebouncedCodeCompletionProvider : DebouncedInlineCompletionProvider() {
return event is InlineCompletionEvent.DocumentChange || containsActiveCompletion
}
private fun ProducerScope<InlineCompletionElement>.getEventListener(
editor: Editor,
infillRequest: InfillRequest
) = object : CodeCompletionEventListener(editor) {
override fun onLineReceived(completionLine: String) {
runInEdt {
var editorLineSuffix = editor.getLineSuffixAfterCaret()
if (editorLineSuffix.isBlank()) {
trySend(
CodeCompletionTextElement(
completionLine,
infillRequest.caretOffset,
TextRange.from(infillRequest.caretOffset, completionLine.length),
)
)
} else {
var caretShift = 0
// TODO: Handle other scenarios
val processedCompletion =
if (completionLine.startsWith(editorLineSuffix.first())) {
caretShift++
editorLineSuffix = editorLineSuffix.substring(1)
completionLine.substring(1)
} else {
completionLine
}
val completionWithRemovedSuffix =
processedCompletion.removeSuffix(editorLineSuffix)
trySend(
CodeCompletionTextElement(
completionWithRemovedSuffix,
infillRequest.caretOffset + caretShift,
TextRange.from(
infillRequest.caretOffset + caretShift,
completionWithRemovedSuffix.length
),
caretShift,
completionLine
)
)
}
}
}
}
private fun Editor.getLineSuffixAfterCaret(): String {
val lineEndOffset = document.getLineEndOffset(document.getLineNumber(caretModel.offset))
return document.getText(
TextRange(
caretModel.offset,
min(lineEndOffset + 1, document.textLength)
)
)
}
private fun sendNextSuggestion(
nextCompletion: String,
request: InlineCompletionRequest
): InlineCompletionSuggestion {
return InlineCompletionSuggestion.Default(channelFlow {
launch {
trySend(

View file

@ -8,23 +8,17 @@ import ee.carlrobert.codegpt.EncodingManager
import ee.carlrobert.codegpt.codecompletions.psi.filePath
import ee.carlrobert.codegpt.codecompletions.psi.readText
const val MAX_PROMPT_TOKENS = 128
const val MAX_PROMPT_TOKENS = 256
class InfillRequest private constructor(
val prefix: String,
val suffix: String,
val caretOffset: Int,
val fileDetails: FileDetails?,
val vcsDetails: VcsDetails?,
val context: InfillContext?
val context: InfillContext?,
val stopTokens: List<String>,
) {
companion object {
fun builder(prefix: String, suffix: String, caretOffset: Int) =
Builder(prefix, suffix, caretOffset)
}
data class VcsDetails(val stagedDiff: String? = null, val unstagedDiff: String? = null)
data class FileDetails(val fileContent: String, val fileExtension: String? = null)
class Builder {
@ -32,35 +26,78 @@ class InfillRequest private constructor(
private val suffix: String
private val caretOffset: Int
private var fileDetails: FileDetails? = null
private var vcsDetails: VcsDetails? = null
private var additionalContext: String? = null
private var context: InfillContext? = null
private var stopTokens: List<String>
constructor(prefix: String, suffix: String, caretOffset: Int) {
constructor(
prefix: String,
suffix: String,
caretOffset: Int,
type: CompletionType = CompletionType.MULTI_LINE
) {
this.prefix = prefix
this.suffix = suffix
this.caretOffset = caretOffset
this.stopTokens = getStopTokens(type)
}
constructor(document: Document, caretOffset: Int) {
constructor(
document: Document,
caretOffset: Int,
type: CompletionType = CompletionType.MULTI_LINE
) {
prefix =
document.getText(TextRange(0, caretOffset))
.truncateText(MAX_PROMPT_TOKENS, false)
suffix =
document.getText(
TextRange(
caretOffset,
document.textLength
)
).truncateText(MAX_PROMPT_TOKENS)
document.getText(TextRange(caretOffset, document.textLength))
.truncateText(MAX_PROMPT_TOKENS)
this.caretOffset = caretOffset
this.stopTokens = getStopTokens(type)
}
fun fileDetails(fileDetails: FileDetails) = apply { this.fileDetails = fileDetails }
fun vcsDetails(vcsDetails: VcsDetails) = apply { this.vcsDetails = vcsDetails }
fun additionalContext(additionalContext: String) =
apply { this.additionalContext = additionalContext }
fun context(context: InfillContext) = apply { this.context = context }
fun build() =
InfillRequest(prefix, suffix, caretOffset, fileDetails, vcsDetails, context)
private fun getStopTokens(type: CompletionType): List<String> {
var whitespaceCount = 0
val lineSuffix = suffix
.takeWhile { char ->
if (char == '\n') false
else if (char.isWhitespace()) whitespaceCount++ < 2
else whitespaceCount < 2
}
val baseTokens = when (type) {
CompletionType.SINGLE_LINE -> emptyList()
else -> listOf("\n\n")
}
return if (lineSuffix.isNotEmpty()) {
baseTokens + lineSuffix
} else {
baseTokens
}
}
fun build(): InfillRequest {
val modifiedPrefix = if (!additionalContext.isNullOrEmpty()) {
"/*\n${additionalContext}\n*/\n\n$prefix"
} else {
prefix
}
return InfillRequest(
modifiedPrefix,
suffix,
caretOffset,
fileDetails,
context,
stopTokens,
)
}
}
}
@ -82,4 +119,9 @@ class ContextElement(val psiElement: PsiElement) {
fun String.truncateText(maxTokens: Int, fromStart: Boolean = true): String {
return service<EncodingManager>().truncateText(this, maxTokens, fromStart)
}
}
enum class CompletionType {
SINGLE_LINE,
MULTI_LINE,
}

View file

@ -4,6 +4,7 @@ import com.intellij.codeInsight.inline.completion.InlineCompletionRequest
import com.intellij.openapi.application.readAction
import com.intellij.openapi.components.service
import com.intellij.openapi.diagnostic.thisLogger
import com.intellij.openapi.fileEditor.FileEditorManager
import com.intellij.openapi.vcs.VcsException
import com.intellij.refactoring.suggested.startOffset
import ee.carlrobert.codegpt.EncodingManager
@ -15,36 +16,42 @@ import ee.carlrobert.codegpt.util.GitUtil
object InfillRequestUtil {
private val logger = thisLogger()
suspend fun buildInfillRequest(request: InlineCompletionRequest): InfillRequest {
suspend fun buildInfillRequest(
request: InlineCompletionRequest,
type: CompletionType
): InfillRequest {
val caretOffset = readAction { request.editor.caretModel.offset }
val infillRequestBuilder = InfillRequest.Builder(request.document, caretOffset)
val infillRequestBuilder = InfillRequest.Builder(request.document, caretOffset, type)
.fileDetails(
InfillRequest.FileDetails(
request.document.text,
request.file.virtualFile.extension
)
)
val project = request.editor.project ?: return infillRequestBuilder.build()
val project = request.editor.project ?: return infillRequestBuilder.build()
val repository = GitUtil.getProjectRepository(project)
if (service<ConfigurationSettings>().state.autocompletionGitContextEnabled && repository != null) {
if (repository != null) {
try {
val stagedDiff = GitUtil.getStagedDiff(project, repository)
val unstagedDiff = GitUtil.getUnstagedDiff(project, repository)
if (stagedDiff.isNotEmpty() || unstagedDiff.isNotEmpty()) {
infillRequestBuilder.vcsDetails(
InfillRequest.VcsDetails(
stagedDiff.joinToString("\n"),
unstagedDiff.joinToString("\n")
)
)
if (unstagedDiff.isNotEmpty()) {
val openedEditorFileNames =
FileEditorManager.getInstance(project).openFiles.map { it.name }
val additionalContext = unstagedDiff
.filter {
it.fileName != request.file.virtualFile.name && it.fileName in openedEditorFileNames
}
.joinToString("\n") { "${it.fileName}\n${it.content}" }
infillRequestBuilder.additionalContext(additionalContext)
}
} catch (e: VcsException) {
logger.error("Failed to get git context", e)
}
}
getInfillContext(request, caretOffset)?.let { infillRequestBuilder.context(it) }
if (service<ConfigurationSettings>().state.codeCompletionSettings.contextAwareEnabled) {
getInfillContext(request, caretOffset)?.let { infillRequestBuilder.context(it) }
}
return infillRequestBuilder.build()
}
@ -54,14 +61,8 @@ object InfillRequestUtil {
caretOffset: Int
): InfillContext? {
val infillContext =
if (service<ConfigurationSettings>().state.autocompletionContextAwareEnabled)
service<CompletionContextService>().findContext(request.editor, caretOffset)
else null
if (infillContext == null) {
return null
}
service<CompletionContextService>().findContext(request.editor, caretOffset)
?: return null
val caretInEnclosingElement =
caretOffset - infillContext.enclosingElement.psiElement.startOffset
val entireText = infillContext.enclosingElement.psiElement.readText()

View file

@ -1,4 +1,4 @@
package ee.carlrobert.codegpt.settings.configuration
package ee.carlrobert.codegpt.settings
import com.intellij.openapi.project.Project
import git4idea.GitUtil

View file

@ -0,0 +1,54 @@
package ee.carlrobert.codegpt.settings.configuration
import com.intellij.openapi.components.service
import com.intellij.openapi.ui.DialogPanel
import com.intellij.ui.components.JBCheckBox
import com.intellij.ui.dsl.builder.panel
import ee.carlrobert.codegpt.CodeGPTBundle
class CodeCompletionConfigurationForm {
private val multiLineCompletionsCheckBox = JBCheckBox(
CodeGPTBundle.get("configurationConfigurable.section.codeCompletion.multiLineCompletions.title"),
service<ConfigurationSettings>().state.codeCompletionSettings.multiLineEnabled
)
private val treeSitterProcessingCheckBox = JBCheckBox(
CodeGPTBundle.get("configurationConfigurable.section.codeCompletion.postProcess.title"),
service<ConfigurationSettings>().state.codeCompletionSettings.treeSitterProcessingEnabled
)
private val gitDiffCheckBox = JBCheckBox(
CodeGPTBundle.get("configurationConfigurable.section.codeCompletion.gitDiff.title"),
service<ConfigurationSettings>().state.codeCompletionSettings.gitDiffEnabled
)
fun createPanel(): DialogPanel {
return panel {
row {
cell(multiLineCompletionsCheckBox)
.comment(CodeGPTBundle.get("configurationConfigurable.section.codeCompletion.multiLineCompletions.description"))
}
row {
cell(treeSitterProcessingCheckBox)
.comment(CodeGPTBundle.get("configurationConfigurable.section.codeCompletion.postProcess.description"))
}
row {
cell(gitDiffCheckBox)
.comment(CodeGPTBundle.get("configurationConfigurable.section.codeCompletion.gitDiff.description"))
}
}
}
fun resetForm(prevState: CodeCompletionSettingsState) {
multiLineCompletionsCheckBox.isSelected = prevState.multiLineEnabled
treeSitterProcessingCheckBox.isSelected = prevState.treeSitterProcessingEnabled
gitDiffCheckBox.isSelected = prevState.gitDiffEnabled
}
fun getFormState(): CodeCompletionSettingsState {
return CodeCompletionSettingsState().apply {
this.multiLineEnabled = multiLineCompletionsCheckBox.isSelected
this.treeSitterProcessingEnabled = treeSitterProcessingCheckBox.isSelected
this.gitDiffEnabled = gitDiffCheckBox.isSelected
}
}
}

View file

@ -31,12 +31,17 @@ class ConfigurationSettingsState : BaseState() {
var methodNameGenerationEnabled by property(true)
var captureCompileErrors by property(true)
var autoFormattingEnabled by property(true)
var autocompletionPostProcessingEnabled by property(false)
var autocompletionContextAwareEnabled by property(false)
var autocompletionGitContextEnabled by property(true)
var tableData by map<String, String>()
var codeCompletionSettings by property(CodeCompletionSettingsState())
init {
tableData.putAll(EditorActionsUtil.DEFAULT_ACTIONS)
}
}
class CodeCompletionSettingsState : BaseState() {
var multiLineEnabled by property(true)
var treeSitterProcessingEnabled by property(true)
var gitDiffEnabled by property(true)
var contextAwareEnabled by property(false)
}

View file

@ -4,12 +4,12 @@ import com.intellij.openapi.components.Service
import com.intellij.openapi.components.Service.Level.PROJECT
import com.intellij.openapi.components.service
import com.intellij.openapi.project.Project
import ee.carlrobert.codegpt.settings.configuration.BranchNamePlaceholderStrategy
import ee.carlrobert.codegpt.settings.configuration.DatePlaceholderStrategy
import ee.carlrobert.codegpt.settings.configuration.Placeholder
import ee.carlrobert.codegpt.settings.configuration.Placeholder.BRANCH_NAME
import ee.carlrobert.codegpt.settings.configuration.Placeholder.DATE_ISO_8601
import ee.carlrobert.codegpt.settings.configuration.PlaceholderStrategy
import ee.carlrobert.codegpt.settings.BranchNamePlaceholderStrategy
import ee.carlrobert.codegpt.settings.DatePlaceholderStrategy
import ee.carlrobert.codegpt.settings.Placeholder
import ee.carlrobert.codegpt.settings.Placeholder.BRANCH_NAME
import ee.carlrobert.codegpt.settings.Placeholder.DATE_ISO_8601
import ee.carlrobert.codegpt.settings.PlaceholderStrategy
@Service(PROJECT)
class CommitMessageTemplate private constructor(project: Project) {

View file

@ -15,7 +15,7 @@ import ee.carlrobert.codegpt.codecompletions.CodeCompletionRequestFactory
import ee.carlrobert.codegpt.codecompletions.InfillPromptTemplate
import ee.carlrobert.codegpt.codecompletions.InfillRequest
import ee.carlrobert.codegpt.completions.CompletionRequestService
import ee.carlrobert.codegpt.settings.configuration.Placeholder
import ee.carlrobert.codegpt.settings.Placeholder
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceCodeCompletionSettingsState
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceFormTabbedPane
import ee.carlrobert.codegpt.ui.OverlayUtil

View file

@ -2,6 +2,7 @@ package ee.carlrobert.codegpt.util
import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.application.PathManager
import com.intellij.openapi.application.runReadAction
import com.intellij.openapi.application.runUndoTransparentWriteAction
import com.intellij.openapi.command.WriteCommandAction
import com.intellij.openapi.components.service
@ -13,6 +14,7 @@ import com.intellij.openapi.fileEditor.FileEditorManager
import com.intellij.openapi.fileEditor.TextEditor
import com.intellij.openapi.fileEditor.impl.FileEditorManagerImpl
import com.intellij.openapi.project.Project
import com.intellij.openapi.util.TextRange
import com.intellij.openapi.util.text.StringUtil
import com.intellij.psi.PsiDocumentManager
import com.intellij.psi.codeStyle.CodeStyleManager
@ -20,6 +22,7 @@ import com.intellij.testFramework.LightVirtualFile
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
import java.time.LocalDateTime
import java.time.format.DateTimeFormatter
import kotlin.math.min
object EditorUtil {
@JvmStatic
@ -108,6 +111,23 @@ object EditorUtil {
}
}
fun String.adjustWhitespaces(editor: Editor): String {
val document = editor.document
val adjustedLine = runReadAction {
val lineNumber = document.getLineNumber(editor.caretModel.offset)
val editorLine = document.getText(
TextRange(
document.getLineStartOffset(lineNumber),
document.getLineEndOffset(lineNumber)
)
)
ee.carlrobert.codegpt.util.StringUtil.adjustWhitespace(this, editorLine)
}
return if (adjustedLine.length != this.length) adjustedLine else this
}
@JvmStatic
fun reformatDocument(
project: Project,

View file

@ -17,38 +17,17 @@ object GitUtil {
private val logger = thisLogger()
@Throws(VcsException::class)
@JvmStatic
fun getStagedDiff(
project: Project,
gitRepository: GitRepository,
includedVersionedFilePaths: List<String> = emptyList()
): List<String> {
return getGitDiff(project, gitRepository, includedVersionedFilePaths, true)
}
@Throws(VcsException::class)
@JvmStatic
fun getUnstagedDiff(
project: Project,
gitRepository: GitRepository,
includedUnversionedFilePaths: List<String> = emptyList()
): List<String> {
return getGitDiff(project, gitRepository, includedUnversionedFilePaths, false)
}
private fun getGitDiff(
project: Project,
gitRepository: GitRepository,
filePaths: List<String>,
staged: Boolean
): List<String> {
filePaths: List<String> = emptyList(),
): List<GitDiffDetails> {
val handler = GitLineHandler(project, gitRepository.root, GitCommand.DIFF)
if (staged) {
handler.addParameters("--cached")
}
handler.addParameters(
"--unified=2",
"--cached",
"--unified=1",
"--diff-filter=AM",
"--no-prefix",
"--no-color",
@ -58,10 +37,22 @@ object GitUtil {
handler.addParameters(path)
}
val commandResult = Git.getInstance().runCommand(handler)
return filterDiffOutput(commandResult.output)
return Git.getInstance().runCommand(handler).outputAsJoinedString
.split("(?=(diff --git [^\n]+))".toRegex())
.filter { it.isNotEmpty() }
.map { diffLine ->
val lines = diffLine.lines()
val fileName = lines.first().split(" ").last().substringAfterLast("/")
val content = lines
.filter { line -> line.isNotEmpty() && !line.startsWith("+++") && !line.startsWith("---") }
.joinToString("\n")
GitDiffDetails(fileName, content)
}
.filter { it.content.isNotEmpty() }
}
data class GitDiffDetails(val fileName: String, val content: String)
@Throws(VcsException::class)
@JvmStatic
fun getProjectRepository(project: Project): GitRepository? {

View file

@ -129,6 +129,13 @@ configurationConfigurable.section.assistant.temperatureField.comment=The value o
configurationConfigurable.section.assistant.maxTokensField.label=Max completion tokens:
configurationConfigurable.section.assistant.maxTokensField.comment=The maximum capacity for completion.
configurationConfigurable.section.assistant.llamacppParams.title=Configuration Options for llama.cpp
configurationConfigurable.section.codeCompletion.title=Code Completion
configurationConfigurable.section.codeCompletion.multiLineCompletions.title=Enable multi-line completions
configurationConfigurable.section.codeCompletion.multiLineCompletions.description=If checked, the completion will be able to span multiple lines.
configurationConfigurable.section.codeCompletion.postProcess.title=Enable tree-sitter post-processing
configurationConfigurable.section.codeCompletion.postProcess.description=If checked, the completion will be post-processed using the tree-sitter parser.
configurationConfigurable.section.codeCompletion.gitDiff.title=Enable git diff context
configurationConfigurable.section.codeCompletion.gitDiff.description=If checked, the user's most recent unstaged git diff will be included when requesting completion.
settingsConfigurable.service.llama.topK.label=Top K:
settingsConfigurable.service.llama.topK.comment=Limit the next token selection to the K most probable tokens (default: 40)
settingsConfigurable.service.llama.topP.label=Top P:
@ -144,13 +151,6 @@ settingsConfigurable.service.custom.openai.linkToDocs=Link to API docs
settingsConfigurable.service.custom.openai.connectionSuccess=Connection successful.
settingsConfigurable.service.custom.openai.connectionFailed=Connection failed.
settingsConfigurable.service.ollama.models.refresh=Refresh Models
configurationConfigurable.section.commitMessage.title=Commit Message Template
configurationConfigurable.section.commitMessage.systemPromptField.label=Prompt template:
configurationConfigurable.section.inlineCompletion.title=Inline Completion
configurationConfigurable.section.inlineCompletion.systemPromptField.label=Prompt:
configurationConfigurable.section.inlineCompletion.systemPromptField.comment=Custom system prompt used for inline code generation (Fill in the Middle (FIM) template).<br/>The {pre}, {suf} and {mid} are replaced depending on the used Model's FIM template.
configurationConfigurable.section.inlineCompletion.delay.label=Delay:
configurationConfigurable.section.inlineCompletion.delay.comment=Inline completion is requested if user is idle for x milliseconds
advancedSettingsConfigurable.displayName=CodeGPT: Advanced Settings
advancedSettingsConfigurable.proxy.title=HTTP/SOCKS Proxy
advancedSettingsConfigurable.proxy.typeComboBoxField.label=Proxy:
@ -281,4 +281,4 @@ suggestionGroupItem.git.displayName=Git
suggestionActionItem.webSearch.displayName=Web
suggestionActionItem.viewDocumentations.displayName=View all docs
suggestionActionItem.createPersona.displayName=Create new persona
suggestionActionItem.createDocumentation.displayName=Create new documentation
suggestionActionItem.createDocumentation.displayName=Create new documentation

View file

@ -1,9 +1,12 @@
package ee.carlrobert.codegpt.codecompletions
import com.intellij.codeInsight.inline.completion.session.InlineCompletionSession.Companion.getOrNull
import com.intellij.openapi.components.service
import com.intellij.openapi.editor.VisualPosition
import com.intellij.testFramework.PlatformTestUtil
import ee.carlrobert.codegpt.CodeGPTKeys.REMAINING_EDITOR_COMPLETION
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTServiceSettings
import ee.carlrobert.codegpt.util.file.FileUtil
import ee.carlrobert.llm.client.http.RequestEntity
import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange
@ -15,21 +18,24 @@ class CodeCompletionServiceTest : IntegrationTest() {
fun `test code completion with CodeGPT provider`() {
useCodeGPTService()
service<ConfigurationSettings>().state.codeCompletionSettings.multiLineEnabled = false
myFixture.configureByText(
"CompletionTest.java",
FileUtil.getResourceContent("/codecompletions/code-completion-file.txt")
)
myFixture.editor.caretModel.moveToVisualPosition(VisualPosition(3, 0))
val prefix = """
${"z".repeat(245)}
xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz
[INPUT]
p
""".trimIndent() // 128 tokens
""".trimIndent()
val suffix = """
[\INPUT]
${"z".repeat(247)}
""".trimIndent() // 128 tokens
zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz
xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
""".trimIndent()
expectCodeGPT(StreamHttpExchange { request: RequestEntity ->
assertThat(request.uri.path).isEqualTo("/v1/code/completions")
assertThat(request.method).isEqualTo("POST")
@ -52,27 +58,30 @@ class CodeCompletionServiceTest : IntegrationTest() {
fun `test code completion with OpenAI provider`() {
useOpenAIService()
service<ConfigurationSettings>().state.codeCompletionSettings.multiLineEnabled = false
myFixture.configureByText(
"CompletionTest.java",
FileUtil.getResourceContent("/codecompletions/code-completion-file.txt")
)
myFixture.editor.caretModel.moveToVisualPosition(VisualPosition(3, 0))
val prefix = """
${"z".repeat(245)}
xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz
[INPUT]
p
""".trimIndent() // 128 tokens
""".trimIndent()
val suffix = """
[\INPUT]
${"z".repeat(247)}
""".trimIndent() // 128 tokens
zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz
xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
""".trimIndent()
expectOpenAI(StreamHttpExchange { request: RequestEntity ->
assertThat(request.uri.path).isEqualTo("/v1/completions")
assertThat(request.method).isEqualTo("POST")
assertThat(request.body)
.extracting("model", "prompt", "suffix", "max_tokens")
.containsExactly("gpt-3.5-turbo-instruct", prefix, suffix, 128)
.containsExactly("gpt-3.5-turbo-instruct", prefix, suffix, 80)
listOf(
jsonMapResponse("choices", jsonArray(jsonMap("text", "ublic "))),
jsonMapResponse("choices", jsonArray(jsonMap("text", "void"))),
@ -89,6 +98,7 @@ class CodeCompletionServiceTest : IntegrationTest() {
fun `_test apply inline suggestions without initial following text`() {
useCodeGPTService()
service<ConfigurationSettings>().state.codeCompletionSettings.multiLineEnabled = false
myFixture.configureByText(
"CompletionTest.java",
"class Node {\n "
@ -205,6 +215,7 @@ class CodeCompletionServiceTest : IntegrationTest() {
fun `_test apply inline suggestions with initial following text`() {
useCodeGPTService()
service<ConfigurationSettings>().state.codeCompletionSettings.multiLineEnabled = false
myFixture.configureByText(
"CompletionTest.java",
"if () {\n \n} else {\n}"
@ -275,6 +286,7 @@ class CodeCompletionServiceTest : IntegrationTest() {
fun `test adjust completion line whitespaces`() {
useCodeGPTService()
service<ConfigurationSettings>().state.codeCompletionSettings.multiLineEnabled = false
myFixture.configureByText(
"CompletionTest.java",
"class Node {\n" +