feat: apply post-processing for code completions (#404)

This commit is contained in:
Carl-Robert 2024-03-11 23:13:10 +02:00 committed by GitHub
parent 12cf5198f8
commit 91dd7bdb43
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 363 additions and 19 deletions

View file

@ -0,0 +1,55 @@
package ee.carlrobert.codegpt.treesitter;
import org.treesitter.TSLanguage;
import org.treesitter.TSNode;
import org.treesitter.TSParser;
import org.treesitter.TSTree;
public class CodeCompletionParser {
protected final TSLanguage language;
public CodeCompletionParser(TSLanguage language) {
this.language = language;
}
public String parse(String prefix, String suffix, String output) {
var result = new StringBuilder(output);
while (result.length() > 0) {
if (containsSyntaxErrors(prefix + result + suffix)) {
result.deleteCharAt(result.length() - 1);
} else {
return result.toString();
}
}
if (output.contains("\n")) {
return parse(prefix, suffix, output.substring(0, output.indexOf("\n")));
}
return output;
}
private boolean containsSyntaxErrors(String input) {
return containsSyntaxErrors(getTree(input).getRootNode());
}
private boolean containsSyntaxErrors(TSNode node) {
if (node.isMissing() || node.hasError()) {
return true;
}
for (int i = 0; i < node.getChildCount(); i++) {
if (containsSyntaxErrors(node.getChild(i))) {
return true;
}
}
return false;
}
private TSTree getTree(String input) {
var parser = new TSParser();
parser.setLanguage(language);
return parser.parseString(null, input);
}
}

View file

@ -0,0 +1,145 @@
package ee.carlrobert.codegpt.treesitter;
import org.treesitter.TSLanguage;
import org.treesitter.TreeSitterCSharp;
import org.treesitter.TreeSitterCpp;
import org.treesitter.TreeSitterCss;
import org.treesitter.TreeSitterDart;
import org.treesitter.TreeSitterDockerfile;
import org.treesitter.TreeSitterElixir;
import org.treesitter.TreeSitterErlang;
import org.treesitter.TreeSitterFortran;
import org.treesitter.TreeSitterGitattributes;
import org.treesitter.TreeSitterGo;
import org.treesitter.TreeSitterGraphql;
import org.treesitter.TreeSitterHtml;
import org.treesitter.TreeSitterJava;
import org.treesitter.TreeSitterJavascript;
import org.treesitter.TreeSitterJson;
import org.treesitter.TreeSitterKotlin;
import org.treesitter.TreeSitterLatex;
import org.treesitter.TreeSitterLua;
import org.treesitter.TreeSitterM68k;
import org.treesitter.TreeSitterMarkdown;
import org.treesitter.TreeSitterObjc;
import org.treesitter.TreeSitterPerl;
import org.treesitter.TreeSitterPhp;
import org.treesitter.TreeSitterPython;
import org.treesitter.TreeSitterRuby;
import org.treesitter.TreeSitterRust;
import org.treesitter.TreeSitterScala;
import org.treesitter.TreeSitterScss;
import org.treesitter.TreeSitterSvelte;
import org.treesitter.TreeSitterSwift;
import org.treesitter.TreeSitterTypescript;
import org.treesitter.TreeSitterYaml;
public class CodeCompletionParserFactory {
public static CodeCompletionParser getParserForFileExtension(String extension)
throws IllegalArgumentException {
return new CodeCompletionParser(getLanguageForExtension(extension));
}
private static TSLanguage getLanguageForExtension(String extension) {
switch (extension) {
case "java":
return new TreeSitterJava();
case "php":
return new TreeSitterPhp();
case "py":
return new TreeSitterPython();
case "ts":
case "tsx":
return new TreeSitterTypescript();
case "js":
case "jsx":
return new TreeSitterJavascript();
case "c":
case "h":
case "cpp":
case "cxx":
case "cc":
case "c++":
case "hpp":
case "hxx":
case "hh":
case "h++":
return new TreeSitterCpp();
case "cs":
return new TreeSitterCSharp();
case "css":
return new TreeSitterCss();
case "dart":
return new TreeSitterDart();
case "dockerfile":
return new TreeSitterDockerfile();
case "elixir":
case "ex":
case "exs":
return new TreeSitterElixir();
case "erl":
case "hrl":
return new TreeSitterErlang();
case "f90":
case "f95":
case "f03":
case "f08":
return new TreeSitterFortran();
case "gitattributes":
return new TreeSitterGitattributes();
case "go":
return new TreeSitterGo();
case "graphql":
case "gql":
return new TreeSitterGraphql();
case "html":
case "htm":
return new TreeSitterHtml();
case "json":
return new TreeSitterJson();
case "kotlin":
case "kt":
case "kts":
return new TreeSitterKotlin();
case "latex":
case "tex":
return new TreeSitterLatex();
case "lua":
return new TreeSitterLua();
case "m68k":
return new TreeSitterM68k();
case "markdown":
case "md":
return new TreeSitterMarkdown();
case "objc":
case "m":
case "mm":
return new TreeSitterObjc();
case "perl":
case "pl":
case "pm":
return new TreeSitterPerl();
case "ruby":
case "rb":
return new TreeSitterRuby();
case "rust":
case "rs":
return new TreeSitterRust();
case "scala":
case "sc":
return new TreeSitterScala();
case "scss":
return new TreeSitterScss();
case "svelte":
return new TreeSitterSvelte();
case "swift":
return new TreeSitterSwift();
case "yml":
case "yaml":
return new TreeSitterYaml();
default:
throw new IllegalArgumentException("Unsupported file extension: " + extension);
}
}
}

View file

@ -0,0 +1,81 @@
package ee.carlrobert.codegpt.treesitter;
import static org.assertj.core.api.Assertions.assertThat;
import org.junit.Test;
public class CodeCompletionParserTest {
@Test
public void shouldGetValidReturnValue() {
var prefix = "class Main {\n"
+ " public int getRandomNumber() {\n"
+ " return ";
var suffix = "\n"
+ " }\n"
+ "}";
var output = "10;}\n}\npublic int getRandomNumber(int k) {";
var parsedResponse = CodeCompletionParserFactory
.getParserForFileExtension("java")
.parse(prefix, suffix, output);
assertThat(parsedResponse).isEqualTo("10;");
}
@Test
public void shouldGetValidParenthesisValue() {
var prefix = "class Main {\n"
+ " public int getRandomNumber(int ";
var suffix = ") {\n"
+ " return 10;\n"
+ " }\n"
+ "}";
var output = "prevNumber) {\n if() {";
var parsedResponse = CodeCompletionParserFactory
.getParserForFileExtension("java")
.parse(prefix, suffix, output);
assertThat(parsedResponse).isEqualTo("prevNumber");
}
@Test
public void shouldHandleFieldDeclaration() {
var prefix = "class Main {\n"
+ "\t\n"
+ " private i";
var suffix = "\n"
+ "\n"
+ " public int getRandomNumber(int prevNumber) {\n"
+ " return Math.of()\n"
+ " }\n"
+ "}";
var output = "nt randomNumber;\n"
+ " \n"
+ " public void get() {";
var result = CodeCompletionParserFactory
.getParserForFileExtension("java")
.parse(prefix, suffix, output);
assertThat(result).isEqualTo("nt randomNumber;");
}
@Test
public void shouldHandleFormalParameters() {
var prefix = "class Main {\n"
+ " public int getRandomNumber(";
var suffix = ") {\n"
+ " return 10;\n"
+ " }\n"
+ "}";
var output = "int prevNumber) }";
var result = CodeCompletionParserFactory
.getParserForFileExtension("java")
.parse(prefix, suffix, output);
assertThat(result).isEqualTo("int prevNumber");
}
}