feat: add project context to code completions (#571)

* feat: add context to code completions

* feat: context finder for Python

* feat: improve and refactor context finder for Python

* feat: include method calls in JavaContextFinder

* test: add JavaContextFinder tests

* test: add PythonContextFinder tests

* fix: CompletionContextService thread

* fix: InfillPromptTemplate context files string

* refactor: simplify findRelevantElements for Java and Python

* feat: only add code snippets instead of files for code-completion context

* feat: add default multi-file prompt template

* fix: add Codestral multi-file FIM

* feat: add feature flag for context aware code completions

* feat: truncate project context elements for code completion
This commit is contained in:
Phil 2024-07-03 16:38:03 +02:00 committed by Carl-Robert Linnupuu
parent c193695fc1
commit 60d71cd301
24 changed files with 762 additions and 79 deletions

View file

@ -42,7 +42,7 @@ intellij {
pluginName.set(properties("pluginName"))
version.set(properties("platformVersion"))
type.set(properties("platformType"))
plugins.set(listOf("java", "Git4Idea"))
plugins.set(listOf("java", "PythonCore:233.14808.12", "Git4Idea"))
}
changelog {

View file

@ -48,6 +48,7 @@ public class ConfigurationComponent {
private final JBCheckBox methodNameGenerationCheckBox;
private final JBCheckBox autoFormattingCheckBox;
private final JBCheckBox autocompletionPostProcessingCheckBox;
private final JBCheckBox autocompletionContextAwareCheckBox;
private final JTextArea systemPromptTextArea;
private final JTextArea commitMessagePromptTextArea;
private final IntegerField maxTokensField;
@ -128,6 +129,10 @@ public class ConfigurationComponent {
CodeGPTBundle.get("configurationConfigurable.autocompletionPostProcessing.label"),
configuration.isAutocompletionPostProcessingEnabled()
);
autocompletionContextAwareCheckBox = new JBCheckBox(
CodeGPTBundle.get("configurationConfigurable.autocompletionContextAwareCheckBox.label"),
configuration.isAutocompletionPostProcessingEnabled()
);
mainPanel = FormBuilder.createFormBuilder()
.addComponent(tablePanel)
@ -138,6 +143,7 @@ public class ConfigurationComponent {
.addComponent(methodNameGenerationCheckBox)
.addComponent(autoFormattingCheckBox)
.addComponent(autocompletionPostProcessingCheckBox)
.addComponent(autocompletionContextAwareCheckBox)
.addVerticalGap(4)
.addComponent(new TitledSeparator(
CodeGPTBundle.get("configurationConfigurable.section.assistant.title")))
@ -166,6 +172,7 @@ public class ConfigurationComponent {
state.setMethodNameGenerationEnabled(methodNameGenerationCheckBox.isSelected());
state.setAutoFormattingEnabled(autoFormattingCheckBox.isSelected());
state.setAutocompletionPostProcessingEnabled(autocompletionPostProcessingCheckBox.isSelected());
state.setAutocompletionContextAwareEnabled(autocompletionContextAwareCheckBox.isSelected());
return state;
}
@ -183,6 +190,8 @@ public class ConfigurationComponent {
autoFormattingCheckBox.setSelected(configuration.isAutoFormattingEnabled());
autocompletionPostProcessingCheckBox.setSelected(
configuration.isAutocompletionPostProcessingEnabled());
autocompletionContextAwareCheckBox.setSelected(
configuration.isAutocompletionContextAwareEnabled());
}
private Map<String, String> getTableData() {

View file

@ -21,6 +21,7 @@ public class ConfigurationState {
private boolean captureCompileErrors = true;
private boolean autoFormattingEnabled = true;
private boolean autocompletionPostProcessingEnabled = false;
private boolean autocompletionContextAwareEnabled = false;
private Map<String, String> tableData = EditorActionsUtil.DEFAULT_ACTIONS;
public String getSystemPrompt() {
@ -127,6 +128,14 @@ public class ConfigurationState {
this.autocompletionPostProcessingEnabled = autocompletionPostProcessingEnabled;
}
public boolean isAutocompletionContextAwareEnabled() {
return autocompletionContextAwareEnabled;
}
public void setAutocompletionContextAwareEnabled(boolean autocompletionContextAwareEnabled) {
this.autocompletionContextAwareEnabled = autocompletionContextAwareEnabled;
}
@Override
public boolean equals(Object o) {
if (this == o) {
@ -145,6 +154,7 @@ public class ConfigurationState {
&& captureCompileErrors == that.captureCompileErrors
&& autoFormattingEnabled == that.autoFormattingEnabled
&& autocompletionPostProcessingEnabled == that.autocompletionPostProcessingEnabled
&& autocompletionContextAwareEnabled == that.autocompletionContextAwareEnabled
&& Objects.equals(systemPrompt, that.systemPrompt)
&& Objects.equals(commitMessagePrompt, that.commitMessagePrompt)
&& Objects.equals(tableData, that.tableData);
@ -155,6 +165,7 @@ public class ConfigurationState {
return Objects.hash(systemPrompt, commitMessagePrompt, maxTokens, temperature,
checkForPluginUpdates, checkForNewScreenshots, createNewChatOnEachAction,
ignoreGitCommitTokenLimit, methodNameGenerationEnabled, captureCompileErrors,
autoFormattingEnabled, autocompletionPostProcessingEnabled, tableData);
autoFormattingEnabled, autocompletionPostProcessingEnabled,
autocompletionContextAwareEnabled, tableData);
}
}

View file

@ -1,6 +1,7 @@
package ee.carlrobert.codegpt.settings.service.llama.form;
import ee.carlrobert.codegpt.codecompletions.InfillPromptTemplate;
import ee.carlrobert.codegpt.codecompletions.InfillRequestDetails;
public class InfillPromptTemplatePanel extends BasePromptTemplatePanel<InfillPromptTemplate> {
@ -16,6 +17,6 @@ public class InfillPromptTemplatePanel extends BasePromptTemplatePanel<InfillPro
@Override
protected String buildPromptDescription(InfillPromptTemplate template) {
return template.buildPrompt("PREFIX", "SUFFIX");
return template.buildPrompt(new InfillRequestDetails("PREFIX", "SUFFIX", null));
}
}

View file

@ -12,7 +12,6 @@ import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettingsState
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings
import ee.carlrobert.llm.client.llama.completion.LlamaCompletionRequest
import ee.carlrobert.llm.client.ollama.completion.request.OllamaCompletionRequest
import ee.carlrobert.llm.client.ollama.completion.request.OllamaParameters
@ -97,7 +96,9 @@ object CodeCompletionRequestFactory {
fun buildLlamaRequest(details: InfillRequestDetails): LlamaCompletionRequest {
val settings = LlamaSettings.getCurrentState()
val promptTemplate = getLlamaInfillPromptTemplate(settings)
val prompt = promptTemplate.buildPrompt(details.prefix, details.suffix)
val prompt = promptTemplate.buildPrompt(details)
println("PROMPT: ")
println(prompt)
return LlamaCompletionRequest.Builder(prompt)
.setN_predict(getMaxTokens(details.prefix, details.suffix))
.setStream(true)
@ -110,7 +111,7 @@ object CodeCompletionRequestFactory {
val settings = service<OllamaSettings>().state
return OllamaCompletionRequest.Builder(
settings.model,
settings.fimTemplate.buildPrompt(details.prefix, details.suffix)
settings.fimTemplate.buildPrompt(details)
)
.setOptions(
OllamaParameters.Builder()
@ -140,7 +141,7 @@ object CodeCompletionRequestFactory {
): Any {
if (value !is String) return value
return when (value) {
"$" + Placeholder.FIM_PROMPT -> template.buildPrompt(details.prefix, details.suffix)
"$" + Placeholder.FIM_PROMPT -> template.buildPrompt(details)
"$" + Placeholder.PREFIX -> details.prefix
"$" + Placeholder.SUFFIX -> details.suffix
else -> value
@ -148,7 +149,9 @@ object CodeCompletionRequestFactory {
}
private fun getMaxTokens(prefix: String, suffix: String): Int {
if (isBoundaryCharacter(prefix[prefix.length - 1]) || isBoundaryCharacter(suffix[0])) {
if ((prefix.isNotEmpty() && isBoundaryCharacter(prefix[prefix.length - 1]))
|| (suffix.isNotEmpty() && isBoundaryCharacter(suffix[0]))
) {
return 16
}
return 36

View file

@ -17,6 +17,7 @@ import com.intellij.openapi.components.service
import com.intellij.openapi.diagnostic.thisLogger
import com.intellij.openapi.util.TextRange
import ee.carlrobert.codegpt.CodeGPTKeys
import ee.carlrobert.codegpt.codecompletions.psi.CompletionContextService
import ee.carlrobert.codegpt.settings.GeneralSettings
import ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings
import ee.carlrobert.codegpt.settings.service.ServiceType
@ -52,28 +53,39 @@ class CodeGPTInlineCompletionProvider : InlineCompletionProvider {
get() = CodeCompletionSuggestionUpdateAdapter()
override suspend fun getSuggestion(request: InlineCompletionRequest): InlineCompletionSingleSuggestion {
val project = request.editor.project
val editor = request.editor
val project = editor.project
if (project == null) {
logger.error("Could not find project")
return InlineCompletionSingleSuggestion.build(elements = emptyFlow())
}
return InlineCompletionSingleSuggestion.build(elements = channelFlow {
val infillRequest = withContext(Dispatchers.EDT) {
InfillRequestDetails.fromInlineCompletionRequest(request)
}
val (prefix, suffix) = withContext(Dispatchers.EDT) {
val caretOffset = request.editor.caretModel.offset
val prefix =
request.document.getText(TextRange(0, caretOffset))
val suffix =
request.document.getText(
TextRange(
caretOffset,
request.document.textLength
val caretOffset = withContext(Dispatchers.EDT) { editor.caretModel.offset }
val infillContext =
if (service<ConfigurationSettings>().state.isAutocompletionContextAwareEnabled)
service<CompletionContextService>().findContext(editor, caretOffset)
else null
val infillRequest = if (infillContext == null) {
val (prefix, suffix) = withContext(Dispatchers.EDT) {
val prefix =
request.document.getText(TextRange(0, caretOffset))
val suffix =
request.document.getText(
TextRange(
caretOffset,
request.document.textLength
)
)
)
Pair(prefix, suffix)
Pair(prefix, suffix)
}
InfillRequestDetails.withoutContext(prefix, suffix)
} else {
// TODO: truncate contextElements if too long?
InfillRequestDetails.withContext(
infillContext,
caretOffset
)
}
currentCall.set(
@ -87,13 +99,14 @@ class CodeGPTInlineCompletionProvider : InlineCompletionProvider {
inlineText = CodeCompletionParserFactory
.getParserForFileExtension(request.file.virtualFile.extension)
.parse(
prefix,
suffix,
// TODO: ?
infillRequest.prefix,
infillRequest.suffix,
inlineText
)
}
request.editor.putUserData(CodeGPTKeys.PREVIOUS_INLAY_TEXT, inlineText)
editor.putUserData(CodeGPTKeys.PREVIOUS_INLAY_TEXT, inlineText)
launch {
try {
trySend(InlineCompletionGrayTextElement(inlineText))

View file

@ -3,48 +3,126 @@ package ee.carlrobert.codegpt.codecompletions
enum class InfillPromptTemplate(val label: String, val stopTokens: List<String>?) {
OPENAI("OpenAI", null) {
override fun buildPrompt(prefix: String, suffix: String): String {
return "<|fim_prefix|> $prefix <|fim_suffix|>$suffix <|fim_middle|>"
override fun buildPrompt(infillDetails: InfillRequestDetails): String {
val infillPrompt =
"<|fim_prefix|> ${infillDetails.prefix} <|fim_suffix|>${infillDetails.suffix} <|fim_middle|>"
return createDefaultMultiFilePrompt(infillDetails, infillPrompt)
}
},
CODE_LLAMA("Code Llama", listOf("<EOT>")) {
override fun buildPrompt(prefix: String, suffix: String): String {
return "<PRE> $prefix <SUF>$suffix <MID>"
override fun buildPrompt(infillDetails: InfillRequestDetails): String {
val infillPrompt = "<PRE> ${infillDetails.prefix} <SUF>${infillDetails.suffix} <MID>"
return createDefaultMultiFilePrompt(infillDetails, infillPrompt)
}
},
CODE_GEMMA(
"CodeGemma Instruct",
listOf("<|file_separator|>", "<|fim_prefix|>", "<|fim_suffix|>", "<|fim_middle|>", "<eos>")
) {
override fun buildPrompt(prefix: String, suffix: String): String {
return "<|fim_prefix|>$prefix<|fim_suffix|>$suffix<|fim_middle|>"
override fun buildPrompt(infillDetails: InfillRequestDetails): String {
// see https://huggingface.co/google/codegemma-7b#for-code-completion
val infillPrompt =
"<|fim_prefix|>${infillDetails.prefix}<|fim_suffix|>${infillDetails.suffix}<|fim_middle|>"
return if (infillDetails.context == null || infillDetails.context.contextElements.isEmpty()) {
infillPrompt
} else {
infillDetails.context.contextElements.map {
"<|file_separator|>${it.filePath()} \n" +
it.text()
}.joinToString("") { it + "\n" } +
"<|file_separator|>${infillDetails.context.enclosingElement.filePath()} \n" +
infillPrompt
}
}
},
CODE_QWEN("CodeQwen1.5", listOf("<|endoftext|>")) {
override fun buildPrompt(prefix: String, suffix: String): String {
return "<fim_prefix>$prefix<fim_suffix>$suffix<fim_middle>"
override fun buildPrompt(infillDetails: InfillRequestDetails): String {
// see https://github.com/QwenLM/CodeQwen1.5?tab=readme-ov-file#2-file-level-code-completion-fill-in-the-middle
val infillPrompt =
"<fim_prefix>${infillDetails.prefix}<fim_suffix>${infillDetails.suffix}<fim_middle>"
return if (infillDetails.context == null || infillDetails.context.contextElements.isEmpty()) {
infillPrompt
} else {
"<reponame>${infillDetails.context.getRepoName()}\n" +
infillDetails.context.contextElements.map {
"<file_sep>${it.filePath()} \n" +
it.text()
}.joinToString("") { it + "\n" } +
"<file_sep>${infillDetails.context.enclosingElement.filePath()} \n" +
infillPrompt
}
}
},
STABILITY("Stability AI", listOf("<|endoftext|>")) {
override fun buildPrompt(prefix: String, suffix: String): String {
return "<fim_prefix>$prefix<fim_suffix>$suffix<fim_middle>"
override fun buildPrompt(infillDetails: InfillRequestDetails): String {
val infillPrompt =
"<fim_prefix>${infillDetails.prefix}<fim_suffix>${infillDetails.suffix}<fim_middle>"
return createDefaultMultiFilePrompt(infillDetails, infillPrompt)
}
},
DEEPSEEK_CODER("DeepSeek Coder", listOf("<|EOT|>")) {
override fun buildPrompt(prefix: String, suffix: String): String {
return "<fim▁begin>$prefix<fim▁hole>$suffix<fim▁end>"
override fun buildPrompt(infillDetails: InfillRequestDetails): String {
// see https://github.com/deepseek-ai/DeepSeek-Coder?tab=readme-ov-file#2-code-insertion
val infillPrompt =
"<fim▁begin>${infillDetails.prefix}<fim▁hole>${infillDetails.suffix}<fim▁end>"
return if (infillDetails.context == null || infillDetails.context.contextElements.isEmpty()) {
infillPrompt
} else {
infillDetails.context.contextElements.map { "#${it.filePath()}\n" + it.text() }
.joinToString("") { it + "\n" } +
"#${infillDetails.context.enclosingElement.filePath()}\n" +
infillPrompt
}
}
},
STAR_CODER("StarCoder2", listOf("<|endoftext|>")) {
override fun buildPrompt(infillDetails: InfillRequestDetails): String {
// see https://huggingface.co/spaces/bigcode/bigcode-playground/blob/main/app.py
val infillPrompt =
"<fim_prefix>${infillDetails.prefix} <fim_suffix> ${infillDetails.suffix}<fim_middle>"
return if (infillDetails.context == null || infillDetails.context.contextElements.isEmpty()) {
infillPrompt
} else {
"<reponame>${infillDetails.context.getRepoName()}" +
infillDetails.context.contextElements.map {
"<filename>${it.filePath()}\n" +
it.text() + "<|endoftext|>"
}.joinToString("") { it + "\n" } +
"<filename>${infillDetails.context.enclosingElement.filePath()} \n" +
infillPrompt
}
}
},
CODESTRAL("Codestral", listOf("</s>")) {
override fun buildPrompt(prefix: String, suffix: String): String {
return "[SUFFIX]$suffix[PREFIX] $prefix"
override fun buildPrompt(infillDetails: InfillRequestDetails): String {
// see https://github.com/mistralai/mistral-common/blob/master/src/mistral_common/tokens/tokenizers/base.py
val infillPrompt = "[SUFFIX]$infillDetails.suffix[PREFIX] $infillDetails.prefix"
return createDefaultMultiFilePrompt(infillDetails, infillPrompt)
}
},
;
};
abstract fun buildPrompt(prefix: String, suffix: String): String
abstract fun buildPrompt(infillDetails: InfillRequestDetails): String
override fun toString(): String {
return label
}
companion object {
private fun createDefaultMultiFilePrompt(
infillDetails: InfillRequestDetails,
infillPrompt: String
): String {
val context = infillDetails.context
return if (context == null || context.contextElements.isEmpty()) {
infillPrompt
} else {
context.contextElements.map {
"# ${it.filePath()} \n" +
it.text()
}.joinToString("") { it + "\n" } +
"# ${context.enclosingElement.filePath()} \n" +
infillPrompt
}
}
}
}

View file

@ -1,55 +1,94 @@
package ee.carlrobert.codegpt.codecompletions
import com.intellij.codeInsight.inline.completion.InlineCompletionRequest
import com.intellij.openapi.editor.Document
import com.intellij.openapi.util.TextRange
import com.intellij.codeInsight.navigation.ImplementationSearcher
import com.intellij.psi.PsiElement
import com.intellij.refactoring.suggested.startOffset
import ee.carlrobert.codegpt.EncodingManager
import kotlin.math.max
import kotlin.math.min
import ee.carlrobert.codegpt.codecompletions.psi.filePath
import ee.carlrobert.codegpt.codecompletions.psi.readText
class InfillRequestDetails(val prefix: String, val suffix: String) {
class InfillRequestDetails(val prefix: String, val suffix: String, val context: InfillContext?) :
ImplementationSearcher() {
companion object {
private const val MAX_OFFSET = 10_000
private const val MAX_PROMPT_TOKENS = 128
private const val MAX_INFILL_PROMPT_TOKENS = 1_000
fun fromInlineCompletionRequest(request: InlineCompletionRequest): InfillRequestDetails {
return fromDocumentWithMaxOffset(
request.editor.document,
request.editor.caretModel.offset,
fun withoutContext(
prefix: String,
suffix: String
): InfillRequestDetails {
val truncatedPrefix = prefix.takeLast(MAX_OFFSET)
val truncatedSuffix = suffix.take(MAX_OFFSET)
return InfillRequestDetails(
truncateText(truncatedPrefix, false),
truncateText(truncatedSuffix, true),
null
)
}
private fun fromDocumentWithMaxOffset(
document: Document,
caretOffset: Int,
fun withContext(
infillContext: InfillContext,
caretOffsetInFile: Int,
): InfillRequestDetails {
val start = max(0, (caretOffset - MAX_OFFSET))
val end = min(document.textLength, (caretOffset + MAX_OFFSET))
return fromDocumentWithCustomRange(document, caretOffset, start, end)
val caretInEnclosingElement =
caretOffsetInFile - infillContext.enclosingElement.psiElement.startOffset
val entireText = infillContext.enclosingElement.psiElement.readText()
val prefix = truncateText(entireText.take(caretInEnclosingElement), false)
val suffix = truncateText(
if (entireText.length < caretInEnclosingElement) "" else entireText.takeLast(
entireText.length - caretInEnclosingElement
), true
)
return InfillRequestDetails(
prefix,
suffix,
truncateContext(prefix + suffix, infillContext)
)
}
private fun fromDocumentWithCustomRange(
document: Document,
caretOffset: Int,
start: Int,
end: Int,
): InfillRequestDetails {
val prefix: String = truncateText(document, start, caretOffset, false)
val suffix: String = truncateText(document, caretOffset, end, true)
return InfillRequestDetails(prefix, suffix)
private fun truncateContext(prompt: String, infillContext: InfillContext): InfillContext {
var promptTokens = EncodingManager.getInstance().countTokens(prompt)
val truncatedContextElements = infillContext.contextElements.takeWhile {
promptTokens += it.tokens
promptTokens <= MAX_INFILL_PROMPT_TOKENS
}.toSet()
return InfillContext(infillContext.enclosingElement, truncatedContextElements)
}
private fun truncateText(
document: Document,
start: Int,
end: Int,
text: String,
fromStart: Boolean
): String {
return EncodingManager.getInstance().truncateText(
document.getText(TextRange(start, end)),
text,
MAX_PROMPT_TOKENS,
fromStart
)
}
}
}
}
class InfillContext(
val enclosingElement: ContextElement,
// TODO: Add some kind of ranking, which contextElements are more important than others
val contextElements: Set<ContextElement>
) {
fun getRepoName(): String = enclosingElement.psiElement.project.name
}
class ContextElement {
val psiElement: PsiElement
var tokens: Int
constructor(psiElement: PsiElement) {
this.psiElement = psiElement
this.tokens = -1
}
fun filePath() = this.psiElement.filePath()
fun text() = this.psiElement.readText()
}

View file

@ -0,0 +1,48 @@
package ee.carlrobert.codegpt.codecompletions.psi
import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.application.ReadAction
import com.intellij.openapi.components.Service
import com.intellij.openapi.editor.Editor
import com.intellij.psi.PsiElement
import com.intellij.psi.PsiManager
import ee.carlrobert.codegpt.EncodingManager
import ee.carlrobert.codegpt.codecompletions.InfillContext
@Service(Service.Level.PROJECT)
class CompletionContextService {
companion object {
private val CONTEXT_FINDERS = mapOf(
"JAVA" to JavaContextFinder::class.java,
"Python" to PythonContextFinder::class.java
)
}
/**
* Determines the [PsiElement] at the given offset,
* determines relevant context with the help of [LanguageContextFinder]s
* and returns the context with the relevant enclosing [PsiElement] and a set of source code [PsiElement]s.
*/
fun findContext(editor: Editor, offset: Int): InfillContext? {
return ReadAction.compute<InfillContext, Throwable> {
val psiFile = PsiManager.getInstance(editor.project!!).findFile(editor.virtualFile!!)!!
val psiElement = psiFile.findElementAt(offset) ?: return@compute null
val contextFinderClass = CONTEXT_FINDERS[psiElement.language.id]
?: // No context finder for the language implemented yet
return@compute null
val contextFinder = ApplicationManager.getApplication().getService(contextFinderClass)
?: // A context finder for the language exists but not available in the used IDE
return@compute null
val context = contextFinder.findContext(psiElement)
val encodingManager = EncodingManager.getInstance()
context.enclosingElement.tokens =
encodingManager.countTokens(context.enclosingElement.psiElement.text)
context.contextElements.forEach {
it.tokens = encodingManager.countTokens(it.psiElement.text)
}
return@compute context
}
}
}

View file

@ -0,0 +1,121 @@
package ee.carlrobert.codegpt.codecompletions.psi
import com.intellij.openapi.roots.JdkUtils
import com.intellij.psi.*
import com.intellij.psi.impl.source.PsiClassReferenceType
import com.intellij.psi.util.PsiTreeUtil
import com.intellij.psi.util.PsiTypesUtil
import ee.carlrobert.codegpt.codecompletions.ContextElement
import ee.carlrobert.codegpt.codecompletions.InfillContext
import kotlinx.collections.immutable.toImmutableSet
class JavaContextFinder : LanguageContextFinder {
/**
* Finds enclosing [PsiMethod] or [PsiClass] of [psiElement] and
* determines source code files of all referenced classes or methods.
*/
override fun findContext(psiElement: PsiElement): InfillContext {
val enclosingElement = findEnclosingElement(psiElement)
val relevantElements = findRelevantElements(enclosingElement, enclosingElement)
val psiTargets = relevantElements.map { findPsiTarget(it) }.flatten().distinct()
val sourceElements = psiTargets.mapNotNull { findSourceElement(it) }
return InfillContext(
ContextElement(enclosingElement),
sourceElements.map { ContextElement(it) }.toSet()
)
}
private fun findEnclosingElement(psiElement: PsiElement): PsiElement =
findEnclosingContext(psiElement)
?: PsiTreeUtil.prevCodeLeaf(psiElement)?.let { findEnclosingContext(it) } ?: psiElement
fun findEnclosingContext(psiElement: PsiElement) =
PsiTreeUtil.findFirstContext(psiElement, true) { it is PsiMethod || it is PsiClass }
private fun findRelevantElements(
psiElement: Collection<PsiElement>,
rootElement: PsiElement
): Set<PsiElement> = psiElement.map { findRelevantElements(it, rootElement) }.flatten()
.distinctBy { it.text }.toSet()
/**
* Finds relevant [PsiTypeElement]s and [PsiMethodCallExpression]s that are used inside of [psiElement].
* If [psiElement] is a [PsiMethod] inside of a [PsiClass] it also adds all class and instance fields.
*/
fun findRelevantElements(psiElement: PsiElement, rootElement: PsiElement): Set<PsiElement> {
val resultSet = mutableSetOf<PsiElement>()
psiElement.accept(object : PsiRecursiveElementWalkingVisitor() {
override fun visitElement(element: PsiElement) {
when (element) {
is PsiTypeElement, is PsiMethodCallExpression -> resultSet.add(element)
is PsiMethod -> {
if (rootElement is PsiClass) {
// If the cursor was not inside a PsiMethod but inside a PsiClass, do not look into the method
return
}
val enclosingContext = findEnclosingContext(element)
if (enclosingContext is PsiClass) {
// add class and instance fields of enclosing class
// TODO: class fields declarations have to be present in the infillPrompt
// (same file as enclosingElement) as well
resultSet.addAll(
findRelevantElements(
(element.parent as PsiClass).allFields.toSet(),
rootElement
)
)
}
super.visitElement(element)
}
else -> super.visitElement(element)
}
}
})
return resultSet.distinctBy { it.text }.toImmutableSet()
}
/**
* Finds [PsiTarget]s to references used inside of [psiElement].
*/
private fun findPsiTarget(psiElement: PsiElement): Set<PsiTarget> {
return when (psiElement) {
is PsiTypeElement -> {
val type = psiElement.type
val clazz = PsiTypesUtil.getPsiClass(type) ?: return emptySet()
// Include generic types, e.g. String for List<String>
if (type is PsiClassReferenceType && type.parameters.isNotEmpty()) {
return setOf(clazz).plus(
type.parameters
.filterIsInstance<PsiClassReferenceType>()
.mapNotNull { it.resolve() }
)
}
setOf(clazz)
}
is PsiMethodCallExpression -> psiElement.resolveMethod()?.let { setOf(it) }
?: emptySet()
is PsiReferenceExpression -> {
val resolvedTarget = psiElement.resolve()
if (resolvedTarget is PsiTarget) setOf(resolvedTarget) else emptySet()
}
else -> emptySet()
}
}
private fun findSourceElement(psiTarget: PsiTarget): PsiElement? {
return if (psiTarget.canNavigateToSource()
&& JdkUtils.getJdkForElement(psiTarget.navigationElement) == null
) {
psiTarget.navigationElement
} else {
null
}
}
}

View file

@ -0,0 +1,11 @@
package ee.carlrobert.codegpt.codecompletions.psi
import com.intellij.psi.PsiElement
import ee.carlrobert.codegpt.codecompletions.InfillContext
interface LanguageContextFinder {
/**
* Determines relevant enclosing [PsiElement] and [PsiElement]s relevant to the context and returns their source code [PsiElement].
*/
fun findContext(psiElement: PsiElement): InfillContext
}

View file

@ -0,0 +1,14 @@
package ee.carlrobert.codegpt.codecompletions.psi
import com.intellij.openapi.application.ApplicationManager
import com.intellij.psi.PsiElement
fun PsiElement.filePath(): String {
return ApplicationManager.getApplication()
.runReadAction<String> { this.containingFile.virtualFile.path }
}
fun PsiElement.readText(): String {
return ApplicationManager.getApplication().runReadAction<String> { this.text }
}

View file

@ -0,0 +1,138 @@
package ee.carlrobert.codegpt.codecompletions.psi
import com.intellij.openapi.project.Project
import com.intellij.psi.PsiElement
import com.intellij.psi.PsiRecursiveElementWalkingVisitor
import com.intellij.psi.util.PsiTreeUtil
import com.intellij.psi.util.findParentOfType
import com.jetbrains.python.psi.*
import com.jetbrains.python.psi.impl.PyBuiltinCache
import com.jetbrains.python.psi.resolve.PyResolveContext
import com.jetbrains.python.psi.resolve.PyResolveUtil
import com.jetbrains.python.psi.types.TypeEvalContext
import ee.carlrobert.codegpt.codecompletions.ContextElement
import ee.carlrobert.codegpt.codecompletions.InfillContext
class PythonContextFinder : LanguageContextFinder {
/**
* Finds enclosing [PyFunction] or [PyClass] of [psiElement] and
* determines source code elements of all used [PyReferenceExpression]s for the context.
*/
override fun findContext(psiElement: PsiElement): InfillContext {
val enclosingElement = findEnclosingElement(psiElement)
val referenceExpressions = findRelevantElements(enclosingElement, enclosingElement)
val declarations =
referenceExpressions.map { findDeclarations(it, psiElement.containingFile.project) }.flatten().distinct()
.filter {
// Filter out elements whose source code is inside the enclosingElement
// e.g. for something like this: [i for i in range(10)] findRelevantElements()
// would return a "PyReferenceExpression: i" which is irrelevant
!it.containingFile.equals(enclosingElement.containingFile) || !enclosingElement.textRange.contains(it.textRange)
}
val sourceElements = declarations.mapNotNull { findSourceElement(it) }
return InfillContext(
ContextElement(enclosingElement),
sourceElements.map { ContextElement(it) }.toSet()
)
}
fun findEnclosingElement(psiElement: PsiElement): PsiElement = findEnclosingContext(psiElement)
?: PsiTreeUtil.prevCodeLeaf(psiElement)?.let { findEnclosingContext(it) } ?: psiElement
private fun findEnclosingContext(psiElement: PsiElement) =
PsiTreeUtil.findFirstContext(psiElement, true) { it is PyFunction || it is PyClass }
private fun findRelevantElements(
psiElement: Collection<PsiElement>,
rootElement: PsiElement
): Set<PyReferenceExpression> =
psiElement.map { findRelevantElements(it, rootElement) }.flatten().distinctBy { it.name }
.toSet()
/**
* Finds [PyReferenceExpression]s inside of [psiElement].
* If [psiElement] is a [PyFunction] inside of a [PyClass] it also adds all [PyReferenceExpression] of any class/instance fields.
*/
fun findRelevantElements(
psiElement: PsiElement,
rootElement: PsiElement
): Set<PyReferenceExpression> {
val resultSet = mutableSetOf<PyReferenceExpression>()
psiElement.accept(object : PsiRecursiveElementWalkingVisitor() {
override fun visitElement(element: PsiElement) {
when (element) {
is PyReferenceExpression -> resultSet.add(element)
is PyFunction -> {
// If the cursor was not inside a PyFunction but inside a PyClass, do not look into the functions
if (rootElement is PyClass) {
return
}
val enclosingContext = findEnclosingContext(element)
if (enclosingContext is PyClass) {
// add class and instance fields of enclosing class
// TODO: class fields declarations have to be present in the infillPrompt
// (same file as enclosingElement) as well
resultSet.addAll(
findRelevantElements(enclosingContext.classAttributes.mapNotNull {
findTargetExpressionAssignment(
it
)
}, rootElement)
)
resultSet.addAll(
findRelevantElements(
enclosingContext.instanceAttributes.mapNotNull {
findTargetExpressionAssignment(it)
},
rootElement
)
)
}
super.visitElement(element)
}
else -> super.visitElement(element)
}
}
})
return resultSet.distinctBy { it.name }.toSet()
}
private fun findTargetExpressionAssignment(targetExpression: PyTargetExpression): PyExpression? {
return targetExpression.findParentOfType<PyAssignmentStatement>()?.assignedValue
}
private fun findDeclarations(
pyReference: PyReferenceExpression,
project: Project
): Set<PsiElement> {
// see https://github.com/JetBrains/intellij-community/blob/ae5290861a1f41b93c48c475239a52faa94b97b0/python/python-psi-impl/src/com/jetbrains/python/codeInsight/PyTargetElementEvaluator.java#L51-L54
return PyResolveUtil.resolveDeclaration(
pyReference.reference,
PyResolveContext.defaultContext(
TypeEvalContext.codeAnalysis(
project,
pyReference.containingFile
)
)
)?.let {
if (PyBuiltinCache.getInstance(pyReference).isBuiltin(it) || it.filePath().contains("/stdlib/")) {
null
} else {
setOf(it)
}
} ?: emptySet()
}
private fun findSourceElement(psiElement: PsiElement): PsiElement? {
val navigationElement = psiElement.navigationElement
val file = navigationElement.containingFile.virtualFile
return if (file.isInLocalFileSystem) {
navigationElement
} else {
null
}
}
}

View file

@ -9,6 +9,7 @@ import com.intellij.ui.components.JBLabel
import com.intellij.util.ui.FormBuilder
import ee.carlrobert.codegpt.CodeGPTBundle
import ee.carlrobert.codegpt.codecompletions.InfillPromptTemplate
import ee.carlrobert.codegpt.codecompletions.InfillRequestDetails
import org.apache.commons.text.StringEscapeUtils
import java.awt.FlowLayout
import javax.swing.Box
@ -63,7 +64,8 @@ class CodeCompletionConfigurationForm(
private fun updatePromptTemplateHelpTooltip(template: InfillPromptTemplate) {
promptTemplateHelpText.setToolTipText(null)
val description = StringEscapeUtils.escapeHtml4(template.buildPrompt("PREFIX", "SUFFIX"))
val description = StringEscapeUtils.escapeHtml4(template.buildPrompt(
InfillRequestDetails("PREFIX", "SUFFIX", null)))
HelpTooltip()
.setTitle(template.toString())
.setDescription("<html><p>$description</p></html>")

View file

@ -145,7 +145,7 @@ class CustomServiceCodeCompletionForm(
private fun testConnection() {
CompletionRequestService.getInstance().getCustomOpenAICompletionAsync(
CodeCompletionRequestFactory.buildCustomRequest(
InfillRequestDetails("Hello", "!"),
InfillRequestDetails("Hello", "!", null),
urlField.text,
tabbedPane.headers,
tabbedPane.body,
@ -186,7 +186,15 @@ class CustomServiceCodeCompletionForm(
private fun updatePromptTemplateHelpTooltip(template: InfillPromptTemplate) {
promptTemplateHelpText.setToolTipText(null)
val description = StringEscapeUtils.escapeHtml4(template.buildPrompt("PREFIX", "SUFFIX"))
val description = StringEscapeUtils.escapeHtml4(
template.buildPrompt(
InfillRequestDetails(
"PREFIX",
"SUFFIX",
null
)
)
)
HelpTooltip()
.setTitle(template.toString())
.setDescription("<html><p>$description</p></html>")

View file

@ -3,4 +3,9 @@
<listener topic="com.intellij.openapi.compiler.CompilationStatusListener"
class="ee.carlrobert.codegpt.ProjectCompilationStatusListener" />
</projectListeners>
<extensions defaultExtensionNs="com.intellij">
<applicationService
serviceImplementation="ee.carlrobert.codegpt.codecompletions.psi.JavaContextFinder"/>
</extensions>
</idea-plugin>

View file

@ -0,0 +1,6 @@
<idea-plugin>
<extensions defaultExtensionNs="com.intellij">
<applicationService
serviceImplementation="ee.carlrobert.codegpt.codecompletions.psi.PythonContextFinder"/>
</extensions>
</idea-plugin>

View file

@ -5,6 +5,13 @@
<depends>com.intellij.modules.platform</depends>
<depends>com.intellij.modules.lang</depends>
<depends optional="true" config-file="plugin-java.xml">com.intellij.modules.java</depends>
<depends optional="true" config-file="plugin-python.xml">com.intellij.modules.python</depends>
<!-- TODO-->
<!-- <depends optional="true" config-file="plugin-js.xml">JavaScript</depends>-->
<!-- <depends optional="true" config-file="plugin-go.xml">org.jetbrains.plugins.go</depends>-->
<!-- <depends optional="true" config-file="plugin-ruby.xml">com.intellij.modules.ruby</depends>-->
<!-- <depends optional="true" config-file="plugin-php.xml">com.jetbrains.php</depends>-->
<!-- <depends optional="true" config-file="plugin-swift.xml">com.intellij.swift</depends>-->
<depends optional="true">Git4Idea</depends>
<projectListeners>
@ -59,6 +66,7 @@
<applicationService serviceImplementation="ee.carlrobert.codegpt.settings.configuration.ConfigurationSettings"/>
<applicationService serviceImplementation="ee.carlrobert.codegpt.settings.advanced.AdvancedSettings"/>
<applicationService serviceImplementation="ee.carlrobert.codegpt.conversations.ConversationsState"/>
<applicationService serviceImplementation="ee.carlrobert.codegpt.codecompletions.psi.CompletionContextService"/>
<inline.completion.provider
id="CodeGPTInlineCompletionProvider"
implementation="ee.carlrobert.codegpt.codecompletions.CodeGPTInlineCompletionProvider"/>

View file

@ -115,6 +115,7 @@ configurationConfigurable.openNewTabCheckBox.label=Open a new chat on each actio
configurationConfigurable.enableMethodNameGeneration.label=Enable method name lookup suggestions
configurationConfigurable.autoFormatting.label=Enable automatic code formatting
configurationConfigurable.autocompletionPostProcessing.label=Enable code completion post processing
configurationConfigurable.autocompletionContextAwareCheckBox.label=Enable project context aware code completion
configurationConfigurable.section.assistant.title=Assistant Configuration
configurationConfigurable.section.assistant.systemPromptField.label=System prompt:
configurationConfigurable.section.assistant.systemPromptField.comment=The system message helps to set the behaviour of the assistant

View file

@ -4,8 +4,6 @@ import com.intellij.codeInsight.inline.completion.session.InlineCompletionSessio
import com.intellij.openapi.editor.VisualPosition
import com.intellij.openapi.util.TextRange
import com.intellij.testFramework.PlatformTestUtil
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings
import ee.carlrobert.codegpt.util.file.FileUtil
import ee.carlrobert.llm.client.http.RequestEntity
import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange
@ -39,7 +37,13 @@ class CodeCompletionServiceTest : IntegrationTest() {
assertThat(request.method).isEqualTo("POST")
assertThat(request.body)
.extracting("prompt")
.isEqualTo(InfillPromptTemplate.CODE_LLAMA.buildPrompt(prefix, suffix))
.isEqualTo(InfillPromptTemplate.CODE_LLAMA.buildPrompt(
InfillRequestDetails(
prefix,
suffix,
null
)
))
listOf(
jsonMapResponse(
e("content", expectedCompletion),

View file

@ -0,0 +1,63 @@
package ee.carlrobert.codegpt.codecompletions.psi
import com.intellij.openapi.editor.VisualPosition
import com.intellij.psi.PsiClass
import com.intellij.psi.PsiJavaFile
import com.intellij.psi.PsiMethod
import com.intellij.testFramework.fixtures.BasePlatformTestCase
import ee.carlrobert.codegpt.util.file.FileUtil.getResourceContent
import org.assertj.core.api.Assertions.assertThat
class JavaContextFinderTest : BasePlatformTestCase() {
private val contextFinder = JavaContextFinder()
fun testFindEnclosingContextMethod() {
val file = myFixture.configureByText(
"Util.java",
getResourceContent("/codecompletions/psi/java.txt")
)
val psiElement =
file.findElementAt(myFixture.editor.visualPositionToOffset(VisualPosition(15, 4)))
val enclosingElement = contextFinder.findEnclosingContext(psiElement!!)
assertThat(enclosingElement)
.isInstanceOf(PsiMethod::class.java)
.extracting("name")
.isEqualTo("randomStrings")
}
fun testFindEnclosingContextClass() {
val file = myFixture.configureByText(
"Util.java",
getResourceContent("/codecompletions/psi/java.txt")
)
val psiElement =
file.findElementAt(myFixture.editor.visualPositionToOffset(VisualPosition(13, 2)))
val contextFinder = contextFinder
val enclosingElement = contextFinder.findEnclosingContext(psiElement!!)
assertThat(enclosingElement)
.isInstanceOf(PsiClass::class.java)
.extracting("name")
.isEqualTo("Util")
}
fun testFindRelevantElements() {
val file = myFixture.configureByText(
"Util.java",
getResourceContent("/codecompletions/psi/java.txt")
)
val psiMethod =
(file as PsiJavaFile).classes[0].findMethodsByName("randomStrings", false)[0]
val relevantElements = contextFinder.findRelevantElements(psiMethod, psiMethod)
assertThat(relevantElements)
.hasSize(4)
.extracting("text")
.containsExactly(
"String",
"List<String>",
"int",
"IntStream.range(0, number).mapToObj(i -> Math.floor(100 * Math.random()) + \"\").toList()"
)
}
}

View file

@ -0,0 +1,67 @@
package ee.carlrobert.codegpt.codecompletions.psi
import com.intellij.openapi.editor.VisualPosition
import com.intellij.psi.PsiJavaFile
import com.intellij.testFramework.fixtures.BasePlatformTestCase
import com.jetbrains.python.psi.PyClass
import com.jetbrains.python.psi.PyFile
import com.jetbrains.python.psi.PyFunction
import ee.carlrobert.codegpt.util.file.FileUtil.getResourceContent
import org.assertj.core.api.Assertions.assertThat
class PythonContextFinderTest : BasePlatformTestCase() {
private val contextFinder = PythonContextFinder()
fun testFindEnclosingElementMethod() {
val file = myFixture.configureByText(
"util.py",
getResourceContent("/codecompletions/psi/python.txt")
)
val psiElement =
file.findElementAt(myFixture.editor.visualPositionToOffset(VisualPosition(11, 1)))
val enclosingElement = contextFinder.findEnclosingElement(psiElement!!)
assertThat(enclosingElement)
.isInstanceOf(PyFunction::class.java)
.extracting("name")
.isEqualTo("randomStrings")
}
fun testFindEnclosingElementClass() {
val file = myFixture.configureByText(
"util.py",
getResourceContent("/codecompletions/psi/python.txt")
)
val psiElement =
file.findElementAt(myFixture.editor.visualPositionToOffset(VisualPosition(6, 2)))
val contextFinder = contextFinder
val enclosingElement = contextFinder.findEnclosingElement(psiElement!!)
assertThat(enclosingElement)
.isInstanceOf(PyClass::class.java)
.extracting("name")
.isEqualTo("Util")
}
fun testFindRelevantElements() {
val file = myFixture.configureByText(
"util.py",
getResourceContent("/codecompletions/psi/python.txt")
)
val psiMethod =
(file as PyFile).topLevelClasses[0].methods[1]
val relevantElements = contextFinder.findRelevantElements(psiMethod, psiMethod)
assertThat(relevantElements)
.hasSize(7)
.extracting("name")
.containsExactly(
"test",
"int",
"List",
"str",
"randint",
"range",
"n",
)
}
}

View file

@ -0,0 +1,19 @@
package ee.carlrobert.codegpt.actions;
import java.util.List;
import java.util.stream.IntStream;
public class Util {
private final String test;
public Util(String test) {
this.test = test;
}
public List<String> randomStrings(int number) {
return IntStream.range(0, number).mapToObj(i -> Math.floor(100 * Math.random()) + "").toList();
}
}

View file

@ -0,0 +1,14 @@
import random
from typing import List
class Util:
test: str = ""
def __init__(self, test: str):
self.test = test
def randomStrings(n: int) -> List[str]:
return [str(random.randint(0, 100)) for _ in range(n)]
}
}