diff --git a/build.gradle.kts b/build.gradle.kts index 40d972c8..1d57e4bf 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -49,6 +49,7 @@ changelog { dependencies { implementation(project(":codegpt-core")) implementation(project(":codegpt-telemetry")) + implementation(project(":codegpt-treesitter")) implementation("com.fasterxml.jackson.datatype:jackson-datatype-jdk8:2.16.1") implementation("com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.16.1") @@ -60,12 +61,7 @@ dependencies { implementation("org.apache.commons:commons-text:1.11.0") implementation("com.knuddels:jtokkit:1.0.0") - testImplementation("org.assertj:assertj-core:3.25.3") testImplementation("org.awaitility:awaitility:4.2.0") - testImplementation("org.junit.jupiter:junit-jupiter-params:5.10.2") - testRuntimeOnly("org.junit.platform:junit-platform-launcher:1.10.2") - testRuntimeOnly("org.junit.jupiter:junit-jupiter-engine:5.10.2") - testRuntimeOnly("org.junit.vintage:junit-vintage-engine:5.10.2") } tasks.register("updateSubmodules") { diff --git a/buildSrc/src/main/kotlin/codegpt.java-conventions.gradle.kts b/buildSrc/src/main/kotlin/codegpt.java-conventions.gradle.kts index b8df294a..429cc9db 100644 --- a/buildSrc/src/main/kotlin/codegpt.java-conventions.gradle.kts +++ b/buildSrc/src/main/kotlin/codegpt.java-conventions.gradle.kts @@ -24,6 +24,12 @@ checkstyle { dependencies { implementation("ee.carlrobert:llm-client:0.6.2") + + testImplementation("org.assertj:assertj-core:3.25.3") + testImplementation("org.junit.jupiter:junit-jupiter-params:5.10.2") + testRuntimeOnly("org.junit.platform:junit-platform-launcher:1.10.2") + testRuntimeOnly("org.junit.jupiter:junit-jupiter-engine:5.10.2") + testRuntimeOnly("org.junit.vintage:junit-vintage-engine:5.10.2") } tasks { diff --git a/codegpt-treesitter/build.gradle.kts b/codegpt-treesitter/build.gradle.kts new file mode 100644 index 00000000..ed515253 --- /dev/null +++ b/codegpt-treesitter/build.gradle.kts @@ -0,0 +1,40 @@ +plugins { + id("codegpt.java-conventions") +} + +dependencies { + implementation("io.github.bonede:tree-sitter:0.21.0") + implementation("io.github.bonede:tree-sitter-erlang:0.1.0") + implementation("io.github.bonede:tree-sitter-elixir:0.1.1") + implementation("io.github.bonede:tree-sitter-dockerfile:0.1.2") + implementation("io.github.bonede:tree-sitter-dart:master") + implementation("io.github.bonede:tree-sitter-css:0.20.0") + implementation("io.github.bonede:tree-sitter-cpp:0.20.3") + implementation("io.github.bonede:tree-sitter-c-sharp:0.20.0") + implementation("io.github.bonede:tree-sitter-fortran:master") + implementation("io.github.bonede:tree-sitter-gitattributes:0.1.3") + implementation("io.github.bonede:tree-sitter-go:0.20.0") + implementation("io.github.bonede:tree-sitter-graphql:master") + implementation("io.github.bonede:tree-sitter-html:0.19.0") + implementation("io.github.bonede:tree-sitter-javascript:0.20.1") + implementation("io.github.bonede:tree-sitter-json:0.20.1") + implementation("io.github.bonede:tree-sitter-kotlin:0.3.1") + implementation("io.github.bonede:tree-sitter-latex:0.3.0") + implementation("io.github.bonede:tree-sitter-lua:2.1.3") + implementation("io.github.bonede:tree-sitter-m68k:0.2.7") + implementation("io.github.bonede:tree-sitter-markdown:0.7.1") + implementation("io.github.bonede:tree-sitter-objc:main") + implementation("io.github.bonede:tree-sitter-perl:0.4.0") + implementation("io.github.bonede:tree-sitter-ruby:0.19.0") + implementation("io.github.bonede:tree-sitter-rust:0.20.4") + implementation("io.github.bonede:tree-sitter-scala:0.20.2") + implementation("io.github.bonede:tree-sitter-scss:1.0.0") + implementation("io.github.bonede:tree-sitter-svelte:0.11.0") + implementation("io.github.bonede:tree-sitter-swift:0.3.6") + implementation("io.github.bonede:tree-sitter-yaml:0.5.0") + implementation("io.github.bonede:tree-sitter-java:0.20.2") + implementation("io.github.bonede:tree-sitter-python:0.20.4") + implementation("io.github.bonede:tree-sitter-php:0.20.0") + implementation("io.github.bonede:tree-sitter-typescript:0.20.3") + implementation("io.github.bonede:tree-sitter-query:0.1.0") +} \ No newline at end of file diff --git a/codegpt-treesitter/src/main/java/ee/carlrobert/codegpt/treesitter/CodeCompletionParser.java b/codegpt-treesitter/src/main/java/ee/carlrobert/codegpt/treesitter/CodeCompletionParser.java new file mode 100644 index 00000000..6b388f2d --- /dev/null +++ b/codegpt-treesitter/src/main/java/ee/carlrobert/codegpt/treesitter/CodeCompletionParser.java @@ -0,0 +1,55 @@ +package ee.carlrobert.codegpt.treesitter; + +import org.treesitter.TSLanguage; +import org.treesitter.TSNode; +import org.treesitter.TSParser; +import org.treesitter.TSTree; + +public class CodeCompletionParser { + + protected final TSLanguage language; + + public CodeCompletionParser(TSLanguage language) { + this.language = language; + } + + public String parse(String prefix, String suffix, String output) { + var result = new StringBuilder(output); + while (result.length() > 0) { + if (containsSyntaxErrors(prefix + result + suffix)) { + result.deleteCharAt(result.length() - 1); + } else { + return result.toString(); + } + } + + if (output.contains("\n")) { + return parse(prefix, suffix, output.substring(0, output.indexOf("\n"))); + } + + return output; + } + + private boolean containsSyntaxErrors(String input) { + return containsSyntaxErrors(getTree(input).getRootNode()); + } + + private boolean containsSyntaxErrors(TSNode node) { + if (node.isMissing() || node.hasError()) { + return true; + } + + for (int i = 0; i < node.getChildCount(); i++) { + if (containsSyntaxErrors(node.getChild(i))) { + return true; + } + } + return false; + } + + private TSTree getTree(String input) { + var parser = new TSParser(); + parser.setLanguage(language); + return parser.parseString(null, input); + } +} diff --git a/codegpt-treesitter/src/main/java/ee/carlrobert/codegpt/treesitter/CodeCompletionParserFactory.java b/codegpt-treesitter/src/main/java/ee/carlrobert/codegpt/treesitter/CodeCompletionParserFactory.java new file mode 100644 index 00000000..9bbe6dd4 --- /dev/null +++ b/codegpt-treesitter/src/main/java/ee/carlrobert/codegpt/treesitter/CodeCompletionParserFactory.java @@ -0,0 +1,145 @@ +package ee.carlrobert.codegpt.treesitter; + +import org.treesitter.TSLanguage; +import org.treesitter.TreeSitterCSharp; +import org.treesitter.TreeSitterCpp; +import org.treesitter.TreeSitterCss; +import org.treesitter.TreeSitterDart; +import org.treesitter.TreeSitterDockerfile; +import org.treesitter.TreeSitterElixir; +import org.treesitter.TreeSitterErlang; +import org.treesitter.TreeSitterFortran; +import org.treesitter.TreeSitterGitattributes; +import org.treesitter.TreeSitterGo; +import org.treesitter.TreeSitterGraphql; +import org.treesitter.TreeSitterHtml; +import org.treesitter.TreeSitterJava; +import org.treesitter.TreeSitterJavascript; +import org.treesitter.TreeSitterJson; +import org.treesitter.TreeSitterKotlin; +import org.treesitter.TreeSitterLatex; +import org.treesitter.TreeSitterLua; +import org.treesitter.TreeSitterM68k; +import org.treesitter.TreeSitterMarkdown; +import org.treesitter.TreeSitterObjc; +import org.treesitter.TreeSitterPerl; +import org.treesitter.TreeSitterPhp; +import org.treesitter.TreeSitterPython; +import org.treesitter.TreeSitterRuby; +import org.treesitter.TreeSitterRust; +import org.treesitter.TreeSitterScala; +import org.treesitter.TreeSitterScss; +import org.treesitter.TreeSitterSvelte; +import org.treesitter.TreeSitterSwift; +import org.treesitter.TreeSitterTypescript; +import org.treesitter.TreeSitterYaml; + +public class CodeCompletionParserFactory { + + public static CodeCompletionParser getParserForFileExtension(String extension) + throws IllegalArgumentException { + return new CodeCompletionParser(getLanguageForExtension(extension)); + } + + private static TSLanguage getLanguageForExtension(String extension) { + switch (extension) { + case "java": + return new TreeSitterJava(); + case "php": + return new TreeSitterPhp(); + case "py": + return new TreeSitterPython(); + case "ts": + case "tsx": + return new TreeSitterTypescript(); + case "js": + case "jsx": + return new TreeSitterJavascript(); + case "c": + case "h": + case "cpp": + case "cxx": + case "cc": + case "c++": + case "hpp": + case "hxx": + case "hh": + case "h++": + return new TreeSitterCpp(); + case "cs": + return new TreeSitterCSharp(); + case "css": + return new TreeSitterCss(); + case "dart": + return new TreeSitterDart(); + case "dockerfile": + return new TreeSitterDockerfile(); + case "elixir": + case "ex": + case "exs": + return new TreeSitterElixir(); + case "erl": + case "hrl": + return new TreeSitterErlang(); + case "f90": + case "f95": + case "f03": + case "f08": + return new TreeSitterFortran(); + case "gitattributes": + return new TreeSitterGitattributes(); + case "go": + return new TreeSitterGo(); + case "graphql": + case "gql": + return new TreeSitterGraphql(); + case "html": + case "htm": + return new TreeSitterHtml(); + case "json": + return new TreeSitterJson(); + case "kotlin": + case "kt": + case "kts": + return new TreeSitterKotlin(); + case "latex": + case "tex": + return new TreeSitterLatex(); + case "lua": + return new TreeSitterLua(); + case "m68k": + return new TreeSitterM68k(); + case "markdown": + case "md": + return new TreeSitterMarkdown(); + case "objc": + case "m": + case "mm": + return new TreeSitterObjc(); + case "perl": + case "pl": + case "pm": + return new TreeSitterPerl(); + case "ruby": + case "rb": + return new TreeSitterRuby(); + case "rust": + case "rs": + return new TreeSitterRust(); + case "scala": + case "sc": + return new TreeSitterScala(); + case "scss": + return new TreeSitterScss(); + case "svelte": + return new TreeSitterSvelte(); + case "swift": + return new TreeSitterSwift(); + case "yml": + case "yaml": + return new TreeSitterYaml(); + default: + throw new IllegalArgumentException("Unsupported file extension: " + extension); + } + } +} diff --git a/codegpt-treesitter/src/test/java/ee/carlrobert/codegpt/treesitter/CodeCompletionParserTest.java b/codegpt-treesitter/src/test/java/ee/carlrobert/codegpt/treesitter/CodeCompletionParserTest.java new file mode 100644 index 00000000..e1cc1cd8 --- /dev/null +++ b/codegpt-treesitter/src/test/java/ee/carlrobert/codegpt/treesitter/CodeCompletionParserTest.java @@ -0,0 +1,81 @@ +package ee.carlrobert.codegpt.treesitter; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.Test; + +public class CodeCompletionParserTest { + + @Test + public void shouldGetValidReturnValue() { + var prefix = "class Main {\n" + + " public int getRandomNumber() {\n" + + " return "; + var suffix = "\n" + + " }\n" + + "}"; + var output = "10;}\n}\npublic int getRandomNumber(int k) {"; + + var parsedResponse = CodeCompletionParserFactory + .getParserForFileExtension("java") + .parse(prefix, suffix, output); + + assertThat(parsedResponse).isEqualTo("10;"); + } + + @Test + public void shouldGetValidParenthesisValue() { + var prefix = "class Main {\n" + + " public int getRandomNumber(int "; + var suffix = ") {\n" + + " return 10;\n" + + " }\n" + + "}"; + var output = "prevNumber) {\n if() {"; + + var parsedResponse = CodeCompletionParserFactory + .getParserForFileExtension("java") + .parse(prefix, suffix, output); + + assertThat(parsedResponse).isEqualTo("prevNumber"); + } + + @Test + public void shouldHandleFieldDeclaration() { + var prefix = "class Main {\n" + + "\t\n" + + " private i"; + var suffix = "\n" + + "\n" + + " public int getRandomNumber(int prevNumber) {\n" + + " return Math.of()\n" + + " }\n" + + "}"; + var output = "nt randomNumber;\n" + + " \n" + + " public void get() {"; + + var result = CodeCompletionParserFactory + .getParserForFileExtension("java") + .parse(prefix, suffix, output); + + assertThat(result).isEqualTo("nt randomNumber;"); + } + + @Test + public void shouldHandleFormalParameters() { + var prefix = "class Main {\n" + + " public int getRandomNumber("; + var suffix = ") {\n" + + " return 10;\n" + + " }\n" + + "}"; + var output = "int prevNumber) }"; + + var result = CodeCompletionParserFactory + .getParserForFileExtension("java") + .parse(prefix, suffix, output); + + assertThat(result).isEqualTo("int prevNumber"); + } +} diff --git a/settings.gradle.kts b/settings.gradle.kts index 8c8a189c..2921dd58 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -1,3 +1,4 @@ rootProject.name = "CodeGPT" include("codegpt-core") +include("codegpt-treesitter") include(":codegpt-telemetry") diff --git a/src/main/java/ee/carlrobert/codegpt/codecompletions/CodeCompletionEventListener.java b/src/main/java/ee/carlrobert/codegpt/codecompletions/CodeCompletionEventListener.java index 1c67c8a6..9b9b9a20 100644 --- a/src/main/java/ee/carlrobert/codegpt/codecompletions/CodeCompletionEventListener.java +++ b/src/main/java/ee/carlrobert/codegpt/codecompletions/CodeCompletionEventListener.java @@ -7,10 +7,13 @@ import com.intellij.notification.NotificationType; import com.intellij.notification.Notifications; import com.intellij.openapi.diagnostic.Logger; import com.intellij.openapi.editor.Editor; +import com.intellij.openapi.editor.impl.EditorImpl; import com.intellij.openapi.progress.impl.BackgroundableProcessIndicator; import ee.carlrobert.codegpt.CodeGPTBundle; import ee.carlrobert.codegpt.actions.OpenSettingsAction; +import ee.carlrobert.codegpt.treesitter.CodeCompletionParserFactory; import ee.carlrobert.codegpt.ui.OverlayUtil; +import ee.carlrobert.codegpt.util.file.FileUtil; import ee.carlrobert.llm.client.openai.completion.ErrorDetails; import ee.carlrobert.llm.completion.CompletionEventListener; import java.io.IOException; @@ -25,14 +28,17 @@ class CodeCompletionEventListener implements CompletionEventListener { private final Editor editor; private final int caretOffset; + private final InfillRequestDetails requestDetails; private final BackgroundableProcessIndicator progressIndicator; public CodeCompletionEventListener( Editor editor, int caretOffset, + InfillRequestDetails requestDetails, @Nullable BackgroundableProcessIndicator progressIndicator) { this.editor = editor; this.caretOffset = caretOffset; + this.requestDetails = requestDetails; this.progressIndicator = progressIndicator; } @@ -42,15 +48,32 @@ class CodeCompletionEventListener implements CompletionEventListener { progressIndicator.processFinish(); } - PREVIOUS_INLAY_TEXT.set(editor, messageBuilder.toString()); + try { + var project = editor.getProject(); + if (project != null) { + var fileExtension = + FileUtil.getFileExtension(((EditorImpl) editor).getVirtualFile().getName()); + var processedOutput = CodeCompletionParserFactory + .getParserForFileExtension(fileExtension) + .parse( + requestDetails.getPrefix(), + requestDetails.getSuffix(), + messageBuilder.toString()); + handleComplete(processedOutput); + } + } catch (IllegalArgumentException e) { + handleComplete(messageBuilder.toString()); + } + } + + private void handleComplete(String processedOutput) { + PREVIOUS_INLAY_TEXT.set(editor, processedOutput); CodeGPTEditorManager.getInstance().disposeEditorInlays(editor); SwingUtilities.invokeLater(() -> { - if (editor.getCaretModel().getOffset() == caretOffset) { - var inlayText = messageBuilder.toString(); - if (!inlayText.isEmpty()) { - CodeCompletionService.getInstance(requireNonNull(editor.getProject())) - .addInlays(editor, caretOffset, inlayText); - } + if (editor.getCaretModel().getOffset() == caretOffset && !processedOutput.isEmpty()) { + CodeCompletionService + .getInstance(requireNonNull(editor.getProject())) + .addInlays(editor, caretOffset, processedOutput); } }); } diff --git a/src/main/java/ee/carlrobert/codegpt/codecompletions/CodeCompletionService.java b/src/main/java/ee/carlrobert/codegpt/codecompletions/CodeCompletionService.java index 3f1f8947..4e80f0e2 100644 --- a/src/main/java/ee/carlrobert/codegpt/codecompletions/CodeCompletionService.java +++ b/src/main/java/ee/carlrobert/codegpt/codecompletions/CodeCompletionService.java @@ -101,7 +101,7 @@ public final class CodeCompletionService implements Disposable { Void.class, (progressIndicator) -> CompletionRequestService.getInstance().getCodeCompletionAsync( request, - new CodeCompletionEventListener(editor, offset, progressIndicator)), + new CodeCompletionEventListener(editor, offset, request, progressIndicator)), 750, TimeUnit.MILLISECONDS); } @@ -156,7 +156,7 @@ public final class CodeCompletionService implements Disposable { if (inlay != null) { applyCompletion(editor, text, inlay.getOffset()); CodeGPTEditorManager.getInstance().disposeEditorInlays(editor); - return; + break; } } editor.putUserData(CodeGPTKeys.PREVIOUS_INLAY_TEXT, null); diff --git a/src/test/java/ee/carlrobert/codegpt/codecompletions/CodeCompletionServiceTest.java b/src/test/java/ee/carlrobert/codegpt/codecompletions/CodeCompletionServiceTest.java index 86eef8be..7247e59a 100644 --- a/src/test/java/ee/carlrobert/codegpt/codecompletions/CodeCompletionServiceTest.java +++ b/src/test/java/ee/carlrobert/codegpt/codecompletions/CodeCompletionServiceTest.java @@ -48,14 +48,11 @@ public class CodeCompletionServiceTest extends IntegrationTest { () -> { var singleLineInlayElement = editor.getUserData(SINGLE_LINE_INLAY); var multiLineInlayElement = editor.getUserData(MULTI_LINE_INLAY); - if (singleLineInlayElement != null && multiLineInlayElement != null) { + if (singleLineInlayElement != null && multiLineInlayElement == null) { var singleLine = ((InlayInlineElementRenderer) singleLineInlayElement.getRenderer()) .getInlayText(); - var multiLine = - ((InlayBlockElementRenderer) multiLineInlayElement.getRenderer()).getInlayText(); - return "TEST_SINGLE_LINE_OUTPUT".equals(singleLine) - && "TEST_MULTI_LINE_OUTPUT".equals(multiLine); + return "TEST_SINGLE_LINE_OUTPUT".equals(singleLine); } return false; }, 5);