feat: apply post-processing for code completions (#404)

This commit is contained in:
Carl-Robert 2024-03-11 23:13:10 +02:00 committed by GitHub
parent 12cf5198f8
commit 91dd7bdb43
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 363 additions and 19 deletions

View file

@ -49,6 +49,7 @@ changelog {
dependencies {
implementation(project(":codegpt-core"))
implementation(project(":codegpt-telemetry"))
implementation(project(":codegpt-treesitter"))
implementation("com.fasterxml.jackson.datatype:jackson-datatype-jdk8:2.16.1")
implementation("com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.16.1")
@ -60,12 +61,7 @@ dependencies {
implementation("org.apache.commons:commons-text:1.11.0")
implementation("com.knuddels:jtokkit:1.0.0")
testImplementation("org.assertj:assertj-core:3.25.3")
testImplementation("org.awaitility:awaitility:4.2.0")
testImplementation("org.junit.jupiter:junit-jupiter-params:5.10.2")
testRuntimeOnly("org.junit.platform:junit-platform-launcher:1.10.2")
testRuntimeOnly("org.junit.jupiter:junit-jupiter-engine:5.10.2")
testRuntimeOnly("org.junit.vintage:junit-vintage-engine:5.10.2")
}
tasks.register<Exec>("updateSubmodules") {

View file

@ -24,6 +24,12 @@ checkstyle {
// Dependencies declared by this build script.
// NOTE(review): this looks like the shared "codegpt.java-conventions" script, in which case
// every module applying that plugin inherits the entries below — confirm.
dependencies {
implementation("ee.carlrobert:llm-client:0.6.2")
// Test stack: AssertJ for fluent assertions and JUnit Jupiter parameterized tests.
testImplementation("org.assertj:assertj-core:3.25.3")
testImplementation("org.junit.jupiter:junit-jupiter-params:5.10.2")
// JUnit Platform launcher plus both engines: Jupiter runs JUnit 5 tests, Vintage runs
// JUnit 4 style tests (classes that import org.junit.Test).
testRuntimeOnly("org.junit.platform:junit-platform-launcher:1.10.2")
testRuntimeOnly("org.junit.jupiter:junit-jupiter-engine:5.10.2")
testRuntimeOnly("org.junit.vintage:junit-vintage-engine:5.10.2")
}
tasks {

View file

@ -0,0 +1,40 @@
// Build plugins for the codegpt-treesitter module.
// NOTE(review): "codegpt.java-conventions" appears to be a project-local convention plugin;
// confirm where it is defined before changing anything it is expected to provide.
plugins {
id("codegpt.java-conventions")
}
// Tree-sitter Java bindings: the core runtime followed by one grammar artifact per language.
// Each grammar (except tree-sitter-query) backs a TreeSitter* class that
// CodeCompletionParserFactory maps file extensions to.
dependencies {
// Core parser runtime (TSParser, TSTree, TSNode, TSLanguage).
implementation("io.github.bonede:tree-sitter:0.21.0")
// Language grammars.
// NOTE(review): dart, fortran, graphql and objc use a branch name ("master"/"main") as the
// version string rather than a numbered release — confirm these resolve reproducibly.
implementation("io.github.bonede:tree-sitter-erlang:0.1.0")
implementation("io.github.bonede:tree-sitter-elixir:0.1.1")
implementation("io.github.bonede:tree-sitter-dockerfile:0.1.2")
implementation("io.github.bonede:tree-sitter-dart:master")
implementation("io.github.bonede:tree-sitter-css:0.20.0")
implementation("io.github.bonede:tree-sitter-cpp:0.20.3")
implementation("io.github.bonede:tree-sitter-c-sharp:0.20.0")
implementation("io.github.bonede:tree-sitter-fortran:master")
implementation("io.github.bonede:tree-sitter-gitattributes:0.1.3")
implementation("io.github.bonede:tree-sitter-go:0.20.0")
implementation("io.github.bonede:tree-sitter-graphql:master")
implementation("io.github.bonede:tree-sitter-html:0.19.0")
implementation("io.github.bonede:tree-sitter-javascript:0.20.1")
implementation("io.github.bonede:tree-sitter-json:0.20.1")
implementation("io.github.bonede:tree-sitter-kotlin:0.3.1")
implementation("io.github.bonede:tree-sitter-latex:0.3.0")
implementation("io.github.bonede:tree-sitter-lua:2.1.3")
implementation("io.github.bonede:tree-sitter-m68k:0.2.7")
implementation("io.github.bonede:tree-sitter-markdown:0.7.1")
implementation("io.github.bonede:tree-sitter-objc:main")
implementation("io.github.bonede:tree-sitter-perl:0.4.0")
implementation("io.github.bonede:tree-sitter-ruby:0.19.0")
implementation("io.github.bonede:tree-sitter-rust:0.20.4")
implementation("io.github.bonede:tree-sitter-scala:0.20.2")
implementation("io.github.bonede:tree-sitter-scss:1.0.0")
implementation("io.github.bonede:tree-sitter-svelte:0.11.0")
implementation("io.github.bonede:tree-sitter-swift:0.3.6")
implementation("io.github.bonede:tree-sitter-yaml:0.5.0")
implementation("io.github.bonede:tree-sitter-java:0.20.2")
implementation("io.github.bonede:tree-sitter-python:0.20.4")
implementation("io.github.bonede:tree-sitter-php:0.20.0")
implementation("io.github.bonede:tree-sitter-typescript:0.20.3")
// NOTE(review): no file extension in CodeCompletionParserFactory maps to this grammar.
implementation("io.github.bonede:tree-sitter-query:0.1.0")
}

View file

@ -0,0 +1,55 @@
package ee.carlrobert.codegpt.treesitter;
import org.treesitter.TSLanguage;
import org.treesitter.TSNode;
import org.treesitter.TSParser;
import org.treesitter.TSTree;
/**
 * Post-processes a raw code-completion suggestion so that, once inserted between the text
 * before and after the caret, the resulting document parses without syntax errors.
 *
 * <p>The suggestion is shortened one character at a time from the end until the combined
 * {@code prefix + candidate + suffix} text is accepted by the tree-sitter grammar supplied at
 * construction time.
 */
public class CodeCompletionParser {

  protected final TSLanguage language;

  /**
   * @param language tree-sitter grammar used to validate candidate completions
   */
  public CodeCompletionParser(TSLanguage language) {
    this.language = language;
  }

  /**
   * Returns the longest leading part of {@code output} that yields an error-free parse when
   * placed between {@code prefix} and {@code suffix}.
   *
   * <p>If no non-empty leading part parses cleanly, the first line of {@code output} is
   * returned as a best-effort fallback (or {@code output} itself when it has no line break).
   *
   * @param prefix document text before the insertion point
   * @param suffix document text after the insertion point
   * @param output raw completion text to trim
   * @return the trimmed completion, or the fallback described above
   */
  public String parse(String prefix, String suffix, String output) {
    if (output.isEmpty()) {
      return output;
    }

    // One parser is enough for the whole trimming loop; every candidate is parsed from
    // scratch (no previous tree is handed to the parser).
    var parser = new TSParser();
    parser.setLanguage(language);

    var candidate = new StringBuilder(output);
    while (candidate.length() > 0) {
      if (!containsSyntaxErrors(parser, prefix + candidate + suffix)) {
        return candidate.toString();
      }
      candidate.setLength(candidate.length() - 1);
    }

    // Every non-empty leading part of the output was rejected above, which includes every
    // leading part of its first line, so re-validating that line cannot succeed. Return it
    // directly as the fallback.
    int lineBreak = output.indexOf('\n');
    return lineBreak >= 0 ? output.substring(0, lineBreak) : output;
  }

  /** Parses {@code input} and reports whether the resulting tree has any syntax problem. */
  private boolean containsSyntaxErrors(TSParser parser, String input) {
    TSTree tree = parser.parseString(null, input);
    return containsSyntaxErrors(tree.getRootNode());
  }

  /**
   * Depth-first check for nodes flagged as missing or erroneous.
   *
   * <p>NOTE(review): tree-sitter's {@code hasError} is documented to cover descendants as
   * well, which would make the explicit recursion redundant; it is kept until that is
   * confirmed for this binding.
   */
  private boolean containsSyntaxErrors(TSNode node) {
    if (node.isMissing() || node.hasError()) {
      return true;
    }
    int childCount = node.getChildCount();
    for (int i = 0; i < childCount; i++) {
      if (containsSyntaxErrors(node.getChild(i))) {
        return true;
      }
    }
    return false;
  }
}

View file

@ -0,0 +1,145 @@
package ee.carlrobert.codegpt.treesitter;
import org.treesitter.TSLanguage;
import org.treesitter.TreeSitterCSharp;
import org.treesitter.TreeSitterCpp;
import org.treesitter.TreeSitterCss;
import org.treesitter.TreeSitterDart;
import org.treesitter.TreeSitterDockerfile;
import org.treesitter.TreeSitterElixir;
import org.treesitter.TreeSitterErlang;
import org.treesitter.TreeSitterFortran;
import org.treesitter.TreeSitterGitattributes;
import org.treesitter.TreeSitterGo;
import org.treesitter.TreeSitterGraphql;
import org.treesitter.TreeSitterHtml;
import org.treesitter.TreeSitterJava;
import org.treesitter.TreeSitterJavascript;
import org.treesitter.TreeSitterJson;
import org.treesitter.TreeSitterKotlin;
import org.treesitter.TreeSitterLatex;
import org.treesitter.TreeSitterLua;
import org.treesitter.TreeSitterM68k;
import org.treesitter.TreeSitterMarkdown;
import org.treesitter.TreeSitterObjc;
import org.treesitter.TreeSitterPerl;
import org.treesitter.TreeSitterPhp;
import org.treesitter.TreeSitterPython;
import org.treesitter.TreeSitterRuby;
import org.treesitter.TreeSitterRust;
import org.treesitter.TreeSitterScala;
import org.treesitter.TreeSitterScss;
import org.treesitter.TreeSitterSvelte;
import org.treesitter.TreeSitterSwift;
import org.treesitter.TreeSitterTypescript;
import org.treesitter.TreeSitterYaml;
/**
 * Creates {@link CodeCompletionParser} instances backed by the tree-sitter grammar that
 * corresponds to a file extension.
 */
public class CodeCompletionParserFactory {

  /**
   * Returns a parser for the language associated with {@code extension}.
   *
   * @param extension file extension without the leading dot (for example {@code "java"});
   *                  matched case-insensitively
   * @return a new parser configured with the matching grammar
   * @throws IllegalArgumentException if {@code extension} is {@code null} or no grammar is
   *                                  registered for it
   */
  public static CodeCompletionParser getParserForFileExtension(String extension)
      throws IllegalArgumentException {
    return new CodeCompletionParser(getLanguageForExtension(extension));
  }

  /**
   * Maps a file extension to its tree-sitter grammar.
   *
   * <p>A {@code null} extension is reported as {@link IllegalArgumentException} rather than
   * letting the {@code switch} raise a {@link NullPointerException}, so callers that fall
   * back on "unsupported extension" handle both situations the same way.
   */
  private static TSLanguage getLanguageForExtension(String extension) {
    if (extension == null) {
      throw new IllegalArgumentException("Unsupported file extension: " + extension);
    }
    // Locale.ROOT keeps the lower-casing independent of the user's locale.
    switch (extension.toLowerCase(java.util.Locale.ROOT)) {
      case "java":
        return new TreeSitterJava();
      case "php":
        return new TreeSitterPhp();
      case "py":
        return new TreeSitterPython();
      case "ts":
      case "tsx":
        return new TreeSitterTypescript();
      case "js":
      case "jsx":
        return new TreeSitterJavascript();
      case "c":
      case "h":
      case "cpp":
      case "cxx":
      case "cc":
      case "c++":
      case "hpp":
      case "hxx":
      case "hh":
      case "h++":
        return new TreeSitterCpp();
      case "cs":
        return new TreeSitterCSharp();
      case "css":
        return new TreeSitterCss();
      case "dart":
        return new TreeSitterDart();
      case "dockerfile":
        return new TreeSitterDockerfile();
      case "elixir":
      case "ex":
      case "exs":
        return new TreeSitterElixir();
      case "erl":
      case "hrl":
        return new TreeSitterErlang();
      case "f90":
      case "f95":
      case "f03":
      case "f08":
        return new TreeSitterFortran();
      case "gitattributes":
        return new TreeSitterGitattributes();
      case "go":
        return new TreeSitterGo();
      case "graphql":
      case "gql":
        return new TreeSitterGraphql();
      case "html":
      case "htm":
        return new TreeSitterHtml();
      case "json":
        return new TreeSitterJson();
      case "kotlin":
      case "kt":
      case "kts":
        return new TreeSitterKotlin();
      case "latex":
      case "tex":
        return new TreeSitterLatex();
      case "lua":
        return new TreeSitterLua();
      case "m68k":
        return new TreeSitterM68k();
      case "markdown":
      case "md":
        return new TreeSitterMarkdown();
      case "objc":
      case "m":
      case "mm":
        return new TreeSitterObjc();
      case "perl":
      case "pl":
      case "pm":
        return new TreeSitterPerl();
      case "ruby":
      case "rb":
        return new TreeSitterRuby();
      case "rust":
      case "rs":
        return new TreeSitterRust();
      case "scala":
      case "sc":
        return new TreeSitterScala();
      case "scss":
        return new TreeSitterScss();
      case "svelte":
        return new TreeSitterSvelte();
      case "swift":
        return new TreeSitterSwift();
      case "yml":
      case "yaml":
        return new TreeSitterYaml();
      default:
        throw new IllegalArgumentException("Unsupported file extension: " + extension);
    }
  }
}

View file

@ -0,0 +1,81 @@
package ee.carlrobert.codegpt.treesitter;
import static org.assertj.core.api.Assertions.assertThat;
import org.junit.Test;
/**
 * Verifies that raw completions for Java sources are trimmed down to the part that keeps the
 * surrounding document syntactically valid.
 */
public class CodeCompletionParserTest {

  @Test
  public void shouldGetValidReturnValue() {
    var textBeforeCaret = "class Main {\n"
        + " public int getRandomNumber() {\n"
        + " return ";
    var textAfterCaret = "\n"
        + " }\n"
        + "}";
    var rawCompletion = "10;}\n}\npublic int getRandomNumber(int k) {";

    assertThat(completeJava(textBeforeCaret, textAfterCaret, rawCompletion))
        .isEqualTo("10;");
  }

  @Test
  public void shouldGetValidParenthesisValue() {
    var textBeforeCaret = "class Main {\n"
        + " public int getRandomNumber(int ";
    var textAfterCaret = ") {\n"
        + " return 10;\n"
        + " }\n"
        + "}";
    var rawCompletion = "prevNumber) {\n if() {";

    assertThat(completeJava(textBeforeCaret, textAfterCaret, rawCompletion))
        .isEqualTo("prevNumber");
  }

  @Test
  public void shouldHandleFieldDeclaration() {
    var textBeforeCaret = "class Main {\n"
        + "\t\n"
        + " private i";
    var textAfterCaret = "\n"
        + "\n"
        + " public int getRandomNumber(int prevNumber) {\n"
        + " return Math.of()\n"
        + " }\n"
        + "}";
    var rawCompletion = "nt randomNumber;\n"
        + " \n"
        + " public void get() {";

    assertThat(completeJava(textBeforeCaret, textAfterCaret, rawCompletion))
        .isEqualTo("nt randomNumber;");
  }

  @Test
  public void shouldHandleFormalParameters() {
    var textBeforeCaret = "class Main {\n"
        + " public int getRandomNumber(";
    var textAfterCaret = ") {\n"
        + " return 10;\n"
        + " }\n"
        + "}";
    var rawCompletion = "int prevNumber) }";

    assertThat(completeJava(textBeforeCaret, textAfterCaret, rawCompletion))
        .isEqualTo("int prevNumber");
  }

  /** Runs the Java completion parser on the given document fragments and raw completion. */
  private static String completeJava(String prefix, String suffix, String output) {
    var javaParser = CodeCompletionParserFactory.getParserForFileExtension("java");
    return javaParser.parse(prefix, suffix, output);
  }
}

View file

@ -1,3 +1,4 @@
// Root project name and the subprojects that take part in the build.
rootProject.name = "CodeGPT"
include("codegpt-core")
include("codegpt-treesitter")
// NOTE(review): this entry uses a leading ':' while the two above do not; presumably both
// forms resolve to a direct child of the root project — confirm before normalizing.
include(":codegpt-telemetry")

View file

@ -7,10 +7,13 @@ import com.intellij.notification.NotificationType;
import com.intellij.notification.Notifications;
import com.intellij.openapi.diagnostic.Logger;
import com.intellij.openapi.editor.Editor;
import com.intellij.openapi.editor.impl.EditorImpl;
import com.intellij.openapi.progress.impl.BackgroundableProcessIndicator;
import ee.carlrobert.codegpt.CodeGPTBundle;
import ee.carlrobert.codegpt.actions.OpenSettingsAction;
import ee.carlrobert.codegpt.treesitter.CodeCompletionParserFactory;
import ee.carlrobert.codegpt.ui.OverlayUtil;
import ee.carlrobert.codegpt.util.file.FileUtil;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import ee.carlrobert.llm.completion.CompletionEventListener;
import java.io.IOException;
@ -25,14 +28,17 @@ class CodeCompletionEventListener implements CompletionEventListener<String> {
private final Editor editor;
private final int caretOffset;
private final InfillRequestDetails requestDetails;
private final BackgroundableProcessIndicator progressIndicator;
/**
 * Creates a listener bound to a single code-completion request.
 *
 * @param editor            editor that receives the completion inlays
 * @param caretOffset       caret offset at request time; inlays are only added while the
 *                          caret is still at this offset
 * @param requestDetails    infill request whose prefix and suffix are handed to the
 *                          completion post-processor
 * @param progressIndicator optional background progress indicator, may be {@code null}
 */
public CodeCompletionEventListener(
    Editor editor,
    int caretOffset,
    InfillRequestDetails requestDetails,
    @Nullable BackgroundableProcessIndicator progressIndicator) {
  this.progressIndicator = progressIndicator;
  this.requestDetails = requestDetails;
  this.caretOffset = caretOffset;
  this.editor = editor;
}
@ -42,15 +48,32 @@ class CodeCompletionEventListener implements CompletionEventListener<String> {
progressIndicator.processFinish();
}
PREVIOUS_INLAY_TEXT.set(editor, messageBuilder.toString());
try {
var project = editor.getProject();
if (project != null) {
var fileExtension =
FileUtil.getFileExtension(((EditorImpl) editor).getVirtualFile().getName());
var processedOutput = CodeCompletionParserFactory
.getParserForFileExtension(fileExtension)
.parse(
requestDetails.getPrefix(),
requestDetails.getSuffix(),
messageBuilder.toString());
handleComplete(processedOutput);
}
} catch (IllegalArgumentException e) {
handleComplete(messageBuilder.toString());
}
}
private void handleComplete(String processedOutput) {
PREVIOUS_INLAY_TEXT.set(editor, processedOutput);
CodeGPTEditorManager.getInstance().disposeEditorInlays(editor);
SwingUtilities.invokeLater(() -> {
if (editor.getCaretModel().getOffset() == caretOffset) {
var inlayText = messageBuilder.toString();
if (!inlayText.isEmpty()) {
CodeCompletionService.getInstance(requireNonNull(editor.getProject()))
.addInlays(editor, caretOffset, inlayText);
}
if (editor.getCaretModel().getOffset() == caretOffset && !processedOutput.isEmpty()) {
CodeCompletionService
.getInstance(requireNonNull(editor.getProject()))
.addInlays(editor, caretOffset, processedOutput);
}
});
}

View file

@ -101,7 +101,7 @@ public final class CodeCompletionService implements Disposable {
Void.class,
(progressIndicator) -> CompletionRequestService.getInstance().getCodeCompletionAsync(
request,
new CodeCompletionEventListener(editor, offset, progressIndicator)),
new CodeCompletionEventListener(editor, offset, request, progressIndicator)),
750,
TimeUnit.MILLISECONDS);
}
@ -156,7 +156,7 @@ public final class CodeCompletionService implements Disposable {
if (inlay != null) {
applyCompletion(editor, text, inlay.getOffset());
CodeGPTEditorManager.getInstance().disposeEditorInlays(editor);
return;
break;
}
}
editor.putUserData(CodeGPTKeys.PREVIOUS_INLAY_TEXT, null);

View file

@ -48,14 +48,11 @@ public class CodeCompletionServiceTest extends IntegrationTest {
() -> {
var singleLineInlayElement = editor.getUserData(SINGLE_LINE_INLAY);
var multiLineInlayElement = editor.getUserData(MULTI_LINE_INLAY);
if (singleLineInlayElement != null && multiLineInlayElement != null) {
if (singleLineInlayElement != null && multiLineInlayElement == null) {
var singleLine =
((InlayInlineElementRenderer) singleLineInlayElement.getRenderer())
.getInlayText();
var multiLine =
((InlayBlockElementRenderer) multiLineInlayElement.getRenderer()).getInlayText();
return "TEST_SINGLE_LINE_OUTPUT".equals(singleLine)
&& "TEST_MULTI_LINE_OUTPUT".equals(multiLine);
return "TEST_SINGLE_LINE_OUTPUT".equals(singleLine);
}
return false;
}, 5);