feat: support multi-line code completions

This commit is contained in:
Carl-Robert Linnupuu 2024-11-30 22:54:12 +00:00
parent 9251b12c9b
commit a47ffd62e4
23 changed files with 646 additions and 345 deletions

View file

@ -1,44 +1,116 @@
package ee.carlrobert.codegpt.treesitter;
import java.nio.charset.StandardCharsets;
import org.treesitter.TSInputEdit;
import org.treesitter.TSLanguage;
import org.treesitter.TSParser;
import org.treesitter.TSPoint;
import org.treesitter.TSTree;
public class CodeCompletionParser {
protected final TSLanguage language;
private final TSParser parser;
public CodeCompletionParser(TSLanguage language) {
this.language = language;
parser = new TSParser();
parser.setLanguage(language);
}
public String parse(String prefix, String suffix, String output) {
var result = new StringBuilder(output);
String input = prefix + result + suffix;
TSTree currentTree = parser.parseString(null, input);
while (!result.isEmpty()) {
if (containsError(prefix + result + suffix)) {
if (containsError(currentTree)) {
int deletionIndex = prefix.length() + result.length() - 1;
Position pos = getPosition(input, deletionIndex);
int startByte = pos.byteOffset;
int oldEndByte = startByte + getByteLength(result.substring(result.length() - 1));
TSPoint startPoint = pos.point;
TSPoint oldEndPoint = computeOldEndPoint(startPoint, result.charAt(result.length() - 1));
currentTree.edit(
new TSInputEdit(startByte, oldEndByte, startByte, startPoint, oldEndPoint, startPoint));
result.deleteCharAt(result.length() - 1);
if (result.length() > 1 && result.charAt(result.length() - 1) == '{') {
long bracketCount = result.chars().filter(ch -> ch == '{').count();
if (bracketCount == 1) {
var newTree = parser.parseString(currentTree, prefix + result + "}" + suffix);
var treeString = newTree.getRootNode().toString();
if (!treeString.contains("ERROR")) {
return result + "}";
}
}
}
input = prefix + result + suffix;
currentTree = parser.parseString(currentTree, input);
} else {
return result.toString();
}
}
if (output.contains("\n")) {
return parse(prefix, suffix, output.substring(0, output.indexOf("\n")));
var finalResult = output.substring(0, output.indexOf("\n"));
if (finalResult.charAt(finalResult.length() - 1) == '{') {
return finalResult + "}";
}
return finalResult;
}
return output;
}
private boolean containsError(String input) {
var treeString = getTree(input).getRootNode().toString();
private boolean containsError(TSTree tree) {
var treeString = tree.getRootNode().toString();
return treeString.contains("ERROR")
|| treeString.contains("MISSING \"}\"")
|| treeString.contains("MISSING \")\"");
}
private TSTree getTree(String input) {
var parser = new TSParser();
parser.setLanguage(language);
return parser.parseString(null, input);
private Position getPosition(String input, int index) {
int row = 0;
int col = 0;
int byteOffset = 0;
for (int i = 0; i < index; i++) {
char c = input.charAt(i);
int charByteLength = getByteLength(String.valueOf(c));
byteOffset += charByteLength;
if (c == '\n') {
row++;
col = 0;
} else {
col++;
}
}
return new Position(new TSPoint(row, col), byteOffset);
}
private int getByteLength(String str) {
return str.getBytes(StandardCharsets.UTF_8).length;
}
private TSPoint computeOldEndPoint(TSPoint startPoint, char deletedChar) {
int row = startPoint.getRow();
int col = startPoint.getColumn();
if (deletedChar == '\n') {
row++;
col = 0;
} else {
col++;
}
return new TSPoint(row, col);
}
private record Position(TSPoint point, int byteOffset) {
}
}

View file

@ -38,9 +38,7 @@ public class CodeCompletionParserTest {
return 10;
}
}""";
var output = """
prevNumber) {
if() {""";
var output = "prevNumber);";
var parsedResponse = CodeCompletionParserFactory
.getParserForFileExtension("java")