mirror of
https://github.com/carlrobertoh/ProxyAI.git
synced 2026-05-21 19:13:38 +00:00
feat: support multi-line code completions
This commit is contained in:
parent
9251b12c9b
commit
a47ffd62e4
23 changed files with 646 additions and 345 deletions
|
|
@ -1,44 +1,116 @@
|
|||
package ee.carlrobert.codegpt.treesitter;
|
||||
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import org.treesitter.TSInputEdit;
|
||||
import org.treesitter.TSLanguage;
|
||||
import org.treesitter.TSParser;
|
||||
import org.treesitter.TSPoint;
|
||||
import org.treesitter.TSTree;
|
||||
|
||||
public class CodeCompletionParser {
|
||||
|
||||
protected final TSLanguage language;
|
||||
private final TSParser parser;
|
||||
|
||||
public CodeCompletionParser(TSLanguage language) {
|
||||
this.language = language;
|
||||
parser = new TSParser();
|
||||
parser.setLanguage(language);
|
||||
}
|
||||
|
||||
public String parse(String prefix, String suffix, String output) {
|
||||
var result = new StringBuilder(output);
|
||||
String input = prefix + result + suffix;
|
||||
TSTree currentTree = parser.parseString(null, input);
|
||||
|
||||
while (!result.isEmpty()) {
|
||||
if (containsError(prefix + result + suffix)) {
|
||||
if (containsError(currentTree)) {
|
||||
int deletionIndex = prefix.length() + result.length() - 1;
|
||||
Position pos = getPosition(input, deletionIndex);
|
||||
|
||||
int startByte = pos.byteOffset;
|
||||
int oldEndByte = startByte + getByteLength(result.substring(result.length() - 1));
|
||||
|
||||
TSPoint startPoint = pos.point;
|
||||
TSPoint oldEndPoint = computeOldEndPoint(startPoint, result.charAt(result.length() - 1));
|
||||
|
||||
currentTree.edit(
|
||||
new TSInputEdit(startByte, oldEndByte, startByte, startPoint, oldEndPoint, startPoint));
|
||||
|
||||
result.deleteCharAt(result.length() - 1);
|
||||
|
||||
if (result.length() > 1 && result.charAt(result.length() - 1) == '{') {
|
||||
long bracketCount = result.chars().filter(ch -> ch == '{').count();
|
||||
if (bracketCount == 1) {
|
||||
var newTree = parser.parseString(currentTree, prefix + result + "}" + suffix);
|
||||
var treeString = newTree.getRootNode().toString();
|
||||
if (!treeString.contains("ERROR")) {
|
||||
return result + "}";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
input = prefix + result + suffix;
|
||||
|
||||
currentTree = parser.parseString(currentTree, input);
|
||||
} else {
|
||||
return result.toString();
|
||||
}
|
||||
}
|
||||
|
||||
if (output.contains("\n")) {
|
||||
return parse(prefix, suffix, output.substring(0, output.indexOf("\n")));
|
||||
var finalResult = output.substring(0, output.indexOf("\n"));
|
||||
if (finalResult.charAt(finalResult.length() - 1) == '{') {
|
||||
return finalResult + "}";
|
||||
}
|
||||
return finalResult;
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
private boolean containsError(String input) {
|
||||
var treeString = getTree(input).getRootNode().toString();
|
||||
private boolean containsError(TSTree tree) {
|
||||
var treeString = tree.getRootNode().toString();
|
||||
return treeString.contains("ERROR")
|
||||
|| treeString.contains("MISSING \"}\"")
|
||||
|| treeString.contains("MISSING \")\"");
|
||||
}
|
||||
|
||||
private TSTree getTree(String input) {
|
||||
var parser = new TSParser();
|
||||
parser.setLanguage(language);
|
||||
return parser.parseString(null, input);
|
||||
private Position getPosition(String input, int index) {
|
||||
int row = 0;
|
||||
int col = 0;
|
||||
int byteOffset = 0;
|
||||
for (int i = 0; i < index; i++) {
|
||||
char c = input.charAt(i);
|
||||
int charByteLength = getByteLength(String.valueOf(c));
|
||||
byteOffset += charByteLength;
|
||||
|
||||
if (c == '\n') {
|
||||
row++;
|
||||
col = 0;
|
||||
} else {
|
||||
col++;
|
||||
}
|
||||
}
|
||||
return new Position(new TSPoint(row, col), byteOffset);
|
||||
}
|
||||
|
||||
private int getByteLength(String str) {
|
||||
return str.getBytes(StandardCharsets.UTF_8).length;
|
||||
}
|
||||
|
||||
private TSPoint computeOldEndPoint(TSPoint startPoint, char deletedChar) {
|
||||
int row = startPoint.getRow();
|
||||
int col = startPoint.getColumn();
|
||||
|
||||
if (deletedChar == '\n') {
|
||||
row++;
|
||||
col = 0;
|
||||
} else {
|
||||
col++;
|
||||
}
|
||||
return new TSPoint(row, col);
|
||||
}
|
||||
|
||||
private record Position(TSPoint point, int byteOffset) {
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -38,9 +38,7 @@ public class CodeCompletionParserTest {
|
|||
return 10;
|
||||
}
|
||||
}""";
|
||||
var output = """
|
||||
prevNumber) {
|
||||
if() {""";
|
||||
var output = "prevNumber);";
|
||||
|
||||
var parsedResponse = CodeCompletionParserFactory
|
||||
.getParserForFileExtension("java")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue