Merge branch 'master' of github.com:AlexanderLuck/ProxyAI into AlexanderLuck/master

This commit is contained in:
Carl-Robert Linnupuu 2026-03-14 16:58:18 +00:00
commit 8fdeba74da
29 changed files with 1712 additions and 275 deletions

View file

@ -534,6 +534,13 @@ object AgentFactory {
hookManager = hookManager,
)
)
if (SubagentTool.DIAGNOSTICS in selected) tool(
DiagnosticsTool(
project = project,
sessionId = sessionId,
hookManager = hookManager,
)
)
if (SubagentTool.WEB_SEARCH in selected) tool(
WebSearchTool(
workingDirectory = project.basePath ?: System.getProperty("user.dir"),

View file

@ -23,6 +23,7 @@ import com.intellij.openapi.components.service
import com.intellij.openapi.project.Project
import ee.carlrobert.codegpt.EncodingManager
import ee.carlrobert.codegpt.agent.clients.shouldStream
import ee.carlrobert.codegpt.agent.clients.shouldStreamCustomOpenAI
import ee.carlrobert.codegpt.agent.strategy.CODE_AGENT_COMPRESSION
import ee.carlrobert.codegpt.agent.strategy.HistoryCompressionConfig
import ee.carlrobert.codegpt.agent.strategy.SingleRunStrategyProvider
@ -34,7 +35,6 @@ import ee.carlrobert.codegpt.settings.hooks.HookManager
import ee.carlrobert.codegpt.settings.models.ModelSettings
import ee.carlrobert.codegpt.settings.service.FeatureType
import ee.carlrobert.codegpt.settings.service.ServiceType
import ee.carlrobert.codegpt.settings.service.custom.CustomServicesSettings
import ee.carlrobert.codegpt.settings.skills.SkillDiscoveryService
import ee.carlrobert.codegpt.toolwindow.agent.ui.approval.BashPayload
import ee.carlrobert.codegpt.toolwindow.agent.ui.approval.ToolApprovalRequest
@ -89,7 +89,7 @@ object ProxyAIAgent {
val modelSelection =
service<ModelSettings>().getModelSelectionForFeature(FeatureType.AGENT)
val skills = project.service<SkillDiscoveryService>().listSkills()
val stream = shouldStreamAgentToolLoop(project, provider)
val stream = shouldStreamAgentToolLoop(provider)
val projectInstructions = loadProjectInstructions(project.basePath)
val executor = AgentFactory.createExecutor(provider, events)
val pendingMessageQueue = pendingMessages.getOrPut(sessionId) { ArrayDeque() }
@ -268,18 +268,10 @@ object ProxyAIAgent {
}
private fun shouldStreamAgentToolLoop(
project: Project,
provider: ServiceType,
): Boolean {
return when (provider) {
ServiceType.CUSTOM_OPENAI -> {
val selectedServiceId =
project.service<ModelSettings>().getStoredModelForFeature(FeatureType.AGENT)
project.service<CustomServicesSettings>().state.services
.firstOrNull { it.id == selectedServiceId }?.chatCompletionSettings?.shouldStream()
?: false
}
ServiceType.CUSTOM_OPENAI -> shouldStreamCustomOpenAI(FeatureType.AGENT)
ServiceType.GOOGLE -> false
else -> true
}
@ -323,6 +315,13 @@ object ProxyAIAgent {
hookManager = hookManager,
)
)
tool(
DiagnosticsTool(
project = project,
sessionId = sessionId,
hookManager = hookManager,
)
)
tool(
WebSearchTool(
workingDirectory = workingDirectory,

View file

@ -4,6 +4,7 @@ enum class SubagentTool(val id: String, val displayName: String, val isWrite: Bo
READ("read", "Read", false),
TODO_WRITE("todowrite", "TodoWrite", false),
INTELLIJ_SEARCH("intellijsearch", "IntelliJSearch", false),
DIAGNOSTICS("diagnostics", "Diagnostics", false),
WEB_SEARCH("websearch", "WebSearch", false),
WEB_FETCH("webfetch", "WebFetch", false),
MCP("MCP", "MCP", false),

View file

@ -0,0 +1,47 @@
package ee.carlrobert.codegpt.agent
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive
import kotlinx.serialization.json.contentOrNull
// Lenient parser used exclusively for tool-call argument payloads.
private val toolArgumentsJson = Json { ignoreUnknownKeys = true }

/**
 * Parses [rawArgs] as JSON and unwraps any string-encoded nesting until a
 * JSON object is reached, returning that object's canonical string form.
 *
 * @return the normalized JSON object text, or null when the input is blank,
 *         fails to parse, or does not resolve to a JSON object.
 */
internal fun normalizeToolArgumentsJson(rawArgs: String?): String? {
    val trimmed = rawArgs?.trim()
    if (trimmed.isNullOrEmpty()) return null
    val parsed = runCatching { toolArgumentsJson.parseToJsonElement(trimmed) }
        .getOrNull()
        ?: return null
    return normalizeToolArgumentsElement(parsed)?.toString()
}
/**
 * Unwraps a parsed tool-argument payload down to a [JsonObject].
 *
 * Some providers return function-call arguments as a JSON string that itself
 * contains JSON (possibly several layers deep); this walks string primitives,
 * re-parsing each layer, until an object is found.
 *
 * @return the innermost [JsonObject], or null when the element is (or unwraps
 *         to) anything other than an object, is blank, or a nested layer
 *         fails to parse.
 */
private tailrec fun normalizeToolArgumentsElement(element: JsonElement): JsonObject? {
    return when (element) {
        // Already an object: nothing to unwrap.
        is JsonObject -> element
        is JsonPrimitive -> {
            if (!element.isString) {
                // Numbers/booleans/null cannot carry nested JSON arguments.
                null
            } else {
                val nestedPayload = element.contentOrNull?.trim().orEmpty()
                if (nestedPayload.isBlank()) {
                    null
                } else {
                    val nested = runCatching {
                        toolArgumentsJson.parseToJsonElement(nestedPayload)
                    }.getOrNull() ?: return null
                    // Tail call: unwrap the next layer of string-encoded JSON.
                    normalizeToolArgumentsElement(nested)
                }
            }
        }
        // Arrays and other element kinds are not valid argument objects.
        else -> null
    }
}

View file

@ -14,6 +14,7 @@ enum class ToolName(val id: String, val aliases: Set<String> = emptySet()) {
BASH_OUTPUT("BashOutput"),
KILL_SHELL("KillShell"),
INTELLIJ_SEARCH("IntelliJSearch"),
DIAGNOSTICS("Diagnostics"),
WEB_SEARCH("WebSearch"),
WEB_FETCH("WebFetch"),
MCP("MCP"),
@ -94,6 +95,13 @@ object ToolSpecs {
IntelliJSearchTool.Result.serializer()
)
)
register(
ToolSpec(
ToolName.DIAGNOSTICS,
DiagnosticsTool.Args.serializer(),
DiagnosticsTool.Result.serializer()
)
)
register(
ToolSpec(
ToolName.WEB_SEARCH,
@ -193,8 +201,12 @@ object ToolSpecs {
if (serializer == null || payload.isBlank()) {
return null
}
val typedSerializer = serializer as KSerializer<Any>
return runCatching {
json.decodeFromString(serializer as KSerializer<Any>, payload)
json.decodeFromString(typedSerializer, payload)
}.recoverCatching {
val normalized = normalizeToolArgumentsJson(payload) ?: throw it
json.decodeFromString(typedSerializer, normalized)
}.getOrNull()
}
}

View file

@ -21,6 +21,7 @@ import com.fasterxml.jackson.core.type.TypeReference
import com.fasterxml.jackson.databind.ObjectMapper
import ee.carlrobert.codegpt.codecompletions.InfillPromptTemplate
import ee.carlrobert.codegpt.codecompletions.InfillRequest
import ee.carlrobert.codegpt.agent.normalizeToolArgumentsJson
import ee.carlrobert.codegpt.settings.Placeholder
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceCodeCompletionSettingsState
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceChatCompletionSettingsState
@ -40,6 +41,7 @@ import kotlinx.serialization.builtins.ListSerializer
import kotlinx.serialization.descriptors.elementNames
import kotlinx.serialization.json.*
import java.net.URI
import java.util.UUID
import kotlin.io.encoding.ExperimentalEncodingApi
import kotlin.time.Clock
@ -193,7 +195,7 @@ class CustomOpenAILLMClient(
}
if (isResponsesApi) {
return serializeResponsesApiRequest(state, messages, model, tools)
return serializeResponsesApiRequest(state, messages, model, tools, toolChoice)
}
val customParams: CustomOpenAIParams = params.toCustomOpenAIParams(state)
@ -245,9 +247,12 @@ class CustomOpenAILLMClient(
state: CustomServiceChatCompletionSettingsState,
messages: List<OpenAIMessage>,
model: LLModel,
tools: List<OpenAITool>?
tools: List<OpenAITool>?,
toolChoice: OpenAIToolChoice?
): String {
val streamRequest = state.shouldStream()
val inputJson = messages.toResponsesApiItemsJson()
val prompt = renderCustomOpenAIPrompt(messages, json)
return buildJsonObject {
state.body.forEach { (key, value) ->
@ -256,7 +261,8 @@ class CustomOpenAILLMClient(
key = key,
value = value,
streamRequest = streamRequest,
messages = messages,
messagesJson = inputJson,
prompt = prompt,
credential = apiKey,
json = json
)
@ -264,8 +270,9 @@ class CustomOpenAILLMClient(
}
put("model", JsonPrimitive(model.id))
if (!tools.isNullOrEmpty()) {
put("tools", json.encodeToJsonElement(ListSerializer(OpenAITool.serializer()), tools))
put("tools", JsonArray(tools.map { it.toResponsesApiToolJson() }))
}
toolChoice?.toResponsesApiToolChoiceJson()?.let { put("tool_choice", it) }
}.toString()
}
@ -418,10 +425,17 @@ class CustomOpenAILLMClient(
}
override fun decodeStreamingResponse(data: String): CustomOpenAIChatCompletionStreamResponse {
val payload = normalizeSsePayload(data)
?: return CustomOpenAIChatCompletionStreamResponse(
choices = emptyList(),
created = 0,
id = "",
model = ""
)
if (!isResponsesApi) {
return json.decodeFromString(data)
return json.decodeFromString(payload)
}
return adaptResponsesApiStreamEvent(data)
return adaptResponsesApiStreamEvent(payload)
}
override fun decodeResponse(data: String): CustomOpenAIChatCompletionResponse {
@ -448,48 +462,42 @@ class CustomOpenAILLMClient(
)
}
"response.function_call_arguments.delta" -> {
val argsDelta = event["delta"]?.jsonPrimitive?.contentOrNull ?: ""
val callId = event["call_id"]?.jsonPrimitive?.contentOrNull ?: ""
val index = event["output_index"]?.jsonPrimitive?.intOrNull ?: 0
CustomOpenAIStreamChoice(
delta = CustomOpenAIStreamDelta(
toolCalls = listOf(
CustomOpenAIToolCall(
id = callId,
index = index,
function = CustomOpenAIFunction(arguments = argsDelta)
"response.function_call_arguments.delta",
"response.output_item.added" -> null
"response.completed" -> {
val responseObject = event["response"]?.jsonObject
val toolCalls = responseObject?.get("output")
?.jsonArray
?.mapIndexedNotNull { index, item ->
val itemObject = item.jsonObject
if (itemObject["type"]?.jsonPrimitive?.contentOrNull != "function_call") {
return@mapIndexedNotNull null
}
val rawArguments = when (val arguments = itemObject["arguments"]) {
is JsonPrimitive -> arguments.contentOrNull
is JsonObject -> arguments.toString()
else -> null
}
CustomOpenAIToolCall(
id = itemObject["call_id"]?.jsonPrimitive?.contentOrNull,
index = index,
function = CustomOpenAIFunction(
name = itemObject["name"]?.jsonPrimitive?.contentOrNull,
arguments = normalizeToolArgumentsJson(rawArguments) ?: rawArguments.orEmpty()
)
)
}
.orEmpty()
CustomOpenAIStreamChoice(
finishReason = "stop",
delta = CustomOpenAIStreamDelta(
toolCalls = toolCalls.takeIf { it.isNotEmpty() }
)
)
}
"response.output_item.added" -> {
val item = event["item"]?.jsonObject
if (item?.get("type")?.jsonPrimitive?.contentOrNull == "function_call") {
val callId = item["call_id"]?.jsonPrimitive?.contentOrNull ?: ""
val name = item["name"]?.jsonPrimitive?.contentOrNull ?: ""
val index = event["output_index"]?.jsonPrimitive?.intOrNull ?: 0
CustomOpenAIStreamChoice(
delta = CustomOpenAIStreamDelta(
toolCalls = listOf(
CustomOpenAIToolCall(
id = callId,
index = index,
function = CustomOpenAIFunction(name = name)
)
)
)
)
} else null
}
"response.completed" -> CustomOpenAIStreamChoice(
finishReason = "stop",
delta = CustomOpenAIStreamDelta()
)
else -> null
}
@ -526,12 +534,20 @@ class CustomOpenAILLMClient(
}
"function_call" -> {
val rawArguments = when (val arguments = itemObj["arguments"]) {
is JsonPrimitive -> arguments.contentOrNull
is JsonObject -> arguments.toString()
else -> null
}
toolCallsJson.add(buildJsonObject {
put("id", itemObj["call_id"] ?: JsonPrimitive(""))
put("type", JsonPrimitive("function"))
putJsonObject("function") {
put("name", itemObj["name"] ?: JsonPrimitive(""))
put("arguments", itemObj["arguments"] ?: JsonPrimitive("{}"))
put(
"arguments",
JsonPrimitive(normalizeToolArgumentsJson(rawArguments) ?: rawArguments ?: "{}")
)
}
})
}
@ -631,12 +647,12 @@ class CustomOpenAILLMClient(
}
assistantToolCalls.forEach { toolCall ->
val arguments = normalizeToolArgumentsJson(toolCall.function.arguments) ?: "{}"
add(
Message.Tool.Call(
id = toolCall.id,
tool = toolCall.function.name,
content = toolCall.function.arguments.takeIf { it.isNotEmpty() }
?: "{}",
content = arguments,
metaInfo = metaInfo
)
)
@ -706,22 +722,38 @@ internal fun transformCustomOpenAIBodyValue(
messages: List<OpenAIMessage>,
credential: String,
json: Json
): JsonElement {
return transformCustomOpenAIBodyValue(
key = key,
value = value,
streamRequest = streamRequest,
messagesJson = json.encodeToJsonElement(
ListSerializer(OpenAIMessage.serializer()),
messages
),
prompt = renderCustomOpenAIPrompt(messages, json),
credential = credential,
json = json
)
}
private fun transformCustomOpenAIBodyValue(
key: String?,
value: Any?,
streamRequest: Boolean,
messagesJson: JsonElement,
prompt: String,
credential: String,
json: Json
): JsonElement {
return when (value) {
null -> JsonNull
is JsonElement -> value
is String -> when {
!streamRequest && key == "stream" -> JsonPrimitive(false)
CustomServicePlaceholders.isMessages(value) -> json.parseToJsonElement(
json.encodeToString(ListSerializer(OpenAIMessage.serializer()), messages)
)
CustomServicePlaceholders.isMessages(value) -> messagesJson
CustomServicePlaceholders.isPrompt(value) -> JsonPrimitive(
renderCustomOpenAIPrompt(
messages,
json
)
)
CustomServicePlaceholders.isPrompt(value) -> JsonPrimitive(prompt)
value.contains($$"$CUSTOM_SERVICE_API_KEY") -> {
JsonPrimitive(value.replace($$"$CUSTOM_SERVICE_API_KEY", credential))
@ -740,7 +772,8 @@ internal fun transformCustomOpenAIBodyValue(
key = nestedKey.toString(),
value = nestedValue,
streamRequest = streamRequest,
messages = messages,
messagesJson = messagesJson,
prompt = prompt,
credential = credential,
json = json
)
@ -753,7 +786,8 @@ internal fun transformCustomOpenAIBodyValue(
key = null,
value = item,
streamRequest = streamRequest,
messages = messages,
messagesJson = messagesJson,
prompt = prompt,
credential = credential,
json = json
)
@ -766,7 +800,8 @@ internal fun transformCustomOpenAIBodyValue(
key = null,
value = item,
streamRequest = streamRequest,
messages = messages,
messagesJson = messagesJson,
prompt = prompt,
credential = credential,
json = json
)
@ -784,6 +819,174 @@ internal fun renderCustomOpenAIPrompt(messages: List<OpenAIMessage>, json: Json)
}
}
/** Flattens every chat message into the Responses API input items it yields. */
private fun List<OpenAIMessage>.toResponsesApiItemsJson(): JsonArray =
    JsonArray(flatMap { message -> message.toResponsesApiItemsJson() })
/**
 * Converts one chat message into the Responses API input items it contributes.
 *
 * - System/Developer → a single "developer" message item.
 * - User → a single "user" message item.
 * - Assistant → up to one reasoning item, one output-text message item, and
 *   one "function_call" item per tool call, in that order.
 * - Tool → a single "function_call_output" item echoing the call id.
 */
private fun OpenAIMessage.toResponsesApiItemsJson(): List<JsonObject> {
    return when (this) {
        // Both roles map onto the Responses API "developer" role.
        is OpenAIMessage.System, is OpenAIMessage.Developer -> listOf(
            buildResponsesApiInputMessageItemJson(
                role = "developer",
                content = content.toResponsesApiInputContentJson()
            )
        )
        is OpenAIMessage.User -> listOf(
            buildResponsesApiInputMessageItemJson(
                role = "user",
                content = content.toResponsesApiInputContentJson()
            )
        )
        is OpenAIMessage.Assistant -> buildList {
            // Replay earlier reasoning as a summary item so context is kept.
            reasoningContent
                ?.takeIf { it.isNotBlank() }
                ?.let { reasoning ->
                    add(
                        buildJsonObject {
                            put("type", JsonPrimitive("reasoning"))
                            // NOTE(review): a fresh random UUID is used as the
                            // reasoning item id on every request — confirm the
                            // server accepts ids it did not originally issue.
                            put("id", JsonPrimitive(UUID.randomUUID().toString()))
                            putJsonArray("summary") {
                                addJsonObject {
                                    put("type", JsonPrimitive("summary_text"))
                                    put("text", JsonPrimitive(reasoning))
                                }
                            }
                        }
                    )
                }
            // Prior assistant text becomes an output_text message item.
            content?.text()
                ?.takeIf { it.isNotBlank() }
                ?.let { assistantText ->
                    add(
                        buildJsonObject {
                            put("type", JsonPrimitive("message"))
                            put("role", JsonPrimitive("assistant"))
                            putJsonArray("content") {
                                addJsonObject {
                                    put("type", JsonPrimitive("output_text"))
                                    put("text", JsonPrimitive(assistantText))
                                    put("annotations", JsonArray(emptyList()))
                                }
                            }
                        }
                    )
                }
            // Each prior tool call is replayed as a function_call item;
            // arguments are re-normalized so string-wrapped JSON is unwrapped.
            toolCalls.orEmpty().forEach { toolCall ->
                val arguments = normalizeToolArgumentsJson(toolCall.function.arguments) ?: "{}"
                add(
                    buildJsonObject {
                        put("type", JsonPrimitive("function_call"))
                        put("arguments", JsonPrimitive(arguments))
                        put("call_id", JsonPrimitive(toolCall.id))
                        put("name", JsonPrimitive(toolCall.function.name))
                    }
                )
            }
        }
        is OpenAIMessage.Tool -> listOf(
            buildJsonObject {
                put("type", JsonPrimitive("function_call_output"))
                put("call_id", JsonPrimitive(toolCallId))
                put("output", JsonPrimitive(content?.text().orEmpty()))
            }
        )
    }
}
/** Builds a Responses API "message" input item for the given role. */
private fun buildResponsesApiInputMessageItemJson(
    role: String,
    content: JsonArray
): JsonObject = JsonObject(
    mapOf(
        "type" to JsonPrimitive("message"),
        "role" to JsonPrimitive(role),
        "content" to content,
    )
)
/**
 * Converts optional message content into a Responses API input-content array.
 *
 * Multi-part content maps each part individually (parts with no Responses
 * API equivalent are dropped); any other content shape is flattened to a
 * single "input_text" entry. Null content yields an empty array.
 */
private fun Content?.toResponsesApiInputContentJson(): JsonArray {
    if (this == null) {
        return JsonArray(emptyList())
    }
    return when (this) {
        is Content.Parts -> JsonArray(value.mapNotNull { it.toResponsesApiInputContentJson() })
        else -> JsonArray(
            listOf(
                buildJsonObject {
                    put("type", JsonPrimitive("input_text"))
                    put("text", JsonPrimitive(text()))
                }
            )
        )
    }
}
/**
 * Maps a single chat content part to its Responses API input-content object.
 *
 * Field names follow the Responses API's snake_case input schema
 * ("image_url", "file_data", "file_id"); the previous camelCase keys were
 * not part of the schema, so images and files were silently ignored by the
 * server.
 *
 * @return the input-content object, or null for part types the Responses API
 *         input schema does not accept, so callers can drop them.
 */
private fun OpenAIContentPart.toResponsesApiInputContentJson(): JsonObject? {
    return when (this) {
        is OpenAIContentPart.Text -> buildJsonObject {
            put("type", JsonPrimitive("input_text"))
            put("text", JsonPrimitive(text))
        }
        is OpenAIContentPart.Image -> buildJsonObject {
            put("type", JsonPrimitive("input_image"))
            imageUrl.detail?.let { put("detail", JsonPrimitive(it)) }
            put("image_url", JsonPrimitive(imageUrl.url))
        }
        is OpenAIContentPart.File -> buildJsonObject {
            put("type", JsonPrimitive("input_file"))
            file.fileData?.let { put("file_data", JsonPrimitive(it)) }
            file.fileId?.let { put("file_id", JsonPrimitive(it)) }
            file.filename?.let { put("filename", JsonPrimitive(it)) }
        }
        else -> null
    }
}
/**
 * Serializes a chat-completions tool definition into the flattened function
 * tool shape used by the Responses API. A missing parameter schema is
 * replaced with an empty object schema; blank descriptions are omitted.
 */
internal fun OpenAITool.toResponsesApiToolJson(): JsonObject {
    val emptyObjectSchema = buildJsonObject {
        put("type", JsonPrimitive("object"))
        putJsonObject("properties") {}
        putJsonArray("required") {}
    }
    return buildJsonObject {
        put("type", JsonPrimitive("function"))
        put("name", JsonPrimitive(function.name))
        put("parameters", function.parameters ?: emptyObjectSchema)
        val strictMode = function.strict
        if (strictMode != null) {
            put("strict", JsonPrimitive(strictMode))
        }
        val description = function.description
        if (!description.isNullOrBlank()) {
            put("description", JsonPrimitive(description))
        }
    }
}
/**
 * Converts a tool-choice directive into its Responses API representation:
 * a function reference becomes an object, a mode becomes a bare string.
 */
internal fun OpenAIToolChoice.toResponsesApiToolChoiceJson(): JsonElement = when (this) {
    is OpenAIToolChoice.Function -> JsonObject(
        mapOf(
            "type" to JsonPrimitive("function"),
            "name" to JsonPrimitive(function.name),
        )
    )
    is OpenAIToolChoice.Mode -> JsonPrimitive(value)
}
internal object CustomOpenAIChatCompletionRequestSerializer :
CustomOpenAIAdditionalPropertiesFlatteningSerializer(CustomOpenAIChatCompletionRequest.serializer())

View file

@ -0,0 +1,14 @@
package ee.carlrobert.codegpt.agent.clients
import com.intellij.openapi.components.service
import ee.carlrobert.codegpt.settings.service.FeatureType
import ee.carlrobert.codegpt.settings.service.custom.CustomServicesSettings
/**
 * Resolves whether the custom OpenAI service selected for [featureType] has
 * streaming enabled, defaulting to false when the settings cannot be read
 * (e.g. no custom service is configured for the feature).
 */
internal fun shouldStreamCustomOpenAI(featureType: FeatureType): Boolean =
    try {
        service<CustomServicesSettings>()
            .customServiceStateForFeatureType(featureType)
            .chatCompletionSettings
            .shouldStream()
    } catch (_: Throwable) {
        // Mirrors runCatching semantics: any failure means "do not stream".
        false
    }

View file

@ -180,8 +180,11 @@ public class ProxyAILLMClient(
}
}
override fun decodeStreamingResponse(data: String): ProxyAIChatCompletionStreamResponse =
json.decodeFromString(data)
// Decodes one SSE chunk; blank chunks and the [DONE] sentinel are mapped to
// an empty response instead of failing deserialization.
override fun decodeStreamingResponse(data: String): ProxyAIChatCompletionStreamResponse {
    val payload = normalizeSsePayload(data)
        ?: return ProxyAIChatCompletionStreamResponse()
    return json.decodeFromString(payload)
}
override fun decodeResponse(data: String): ProxyAIChatCompletionResponse =
json.decodeFromString(data)

View file

@ -0,0 +1,31 @@
package ee.carlrobert.codegpt.agent.clients
/**
 * Extracts the payload from a raw server-sent-events chunk.
 *
 * Line endings are normalized, then every `data:` field line is collected.
 * Per the SSE specification a field name must begin the line, so `data:`
 * occurring *inside* payload text (e.g. a data URI inside bare JSON) is not
 * treated as a field — the previous `indexOf` matching truncated such
 * payloads mid-line. Multiple data lines are joined with newlines. When the
 * chunk carries no `data:` field at all, the whole trimmed chunk is treated
 * as the payload, since some providers emit bare JSON without SSE framing.
 *
 * @return the payload, or null for blank input or the `[DONE]` sentinel.
 */
internal fun normalizeSsePayload(rawData: String): String? {
    val normalized = rawData
        .replace("\r\n", "\n")
        .replace('\r', '\n')
        .trim()
    if (normalized.isEmpty()) {
        return null
    }
    val dataLines = normalized.lineSequence()
        .mapNotNull { line ->
            // Only honor "data:" when it begins the (trimmed) line.
            val trimmedLine = line.trimStart()
            if (trimmedLine.startsWith("data:")) {
                trimmedLine.removePrefix("data:").trimStart()
            } else {
                null
            }
        }
        .filter { it.isNotEmpty() }
        .toList()
    val payload = if (dataLines.isNotEmpty()) {
        dataLines.joinToString("\n")
    } else {
        normalized
    }
    return payload.takeUnless { it.isBlank() || it == "[DONE]" }
}

View file

@ -23,15 +23,13 @@ import com.intellij.openapi.vfs.LocalFileSystem
import ee.carlrobert.codegpt.ReferencedFile
import ee.carlrobert.codegpt.agent.AgentEvents
import ee.carlrobert.codegpt.agent.MessageWithContext
import ee.carlrobert.codegpt.agent.normalizeToolArgumentsJson
import ee.carlrobert.codegpt.agent.credits.extractCreditsSnapshot
import ee.carlrobert.codegpt.completions.CompletionRequestUtil
import ee.carlrobert.codegpt.settings.service.ServiceType
import ee.carlrobert.codegpt.toolwindow.agent.AgentCreditsEvent
import ee.carlrobert.codegpt.ui.textarea.TagProcessorFactory
import ee.carlrobert.codegpt.ui.textarea.header.tag.TagDetails
import kotlinx.serialization.SerializationException
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.jsonObject
import java.util.*
import ee.carlrobert.codegpt.conversations.message.Message as ChatMessage
@ -186,15 +184,11 @@ private suspend fun AIAgentLLMWriteSession.requestResponses(
.toMessageResponses()
.map {
if (it is Message.Tool.Call) {
try {
// validate json
Json.parseToJsonElement(it.content).jsonObject
val normalizedArgs = normalizeToolArgumentsJson(it.content) ?: "{}"
if (normalizedArgs == it.content) {
it
} catch (_: SerializationException) {
// allows agent to retry the request
it.copy(parts = listOf(it.parts[0].copy(text = "{}")))
} catch (e: Exception) {
throw e
} else {
it.copy(parts = listOf(it.parts[0].copy(text = normalizedArgs)))
}
} else {
it

View file

@ -0,0 +1,123 @@
package ee.carlrobert.codegpt.agent.tools
import ai.koog.agents.core.tools.annotations.LLMDescription
import com.intellij.openapi.components.service
import com.intellij.openapi.project.Project
import ee.carlrobert.codegpt.diagnostics.DiagnosticsFilter
import ee.carlrobert.codegpt.diagnostics.ProjectDiagnosticsService
import ee.carlrobert.codegpt.settings.ProxyAISettingsService
import ee.carlrobert.codegpt.settings.ToolPermissionPolicy
import ee.carlrobert.codegpt.settings.hooks.HookManager
import ee.carlrobert.codegpt.tokens.truncateToolResult
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
/**
 * Agent tool that reads the IDE's current diagnostics (compiler errors and
 * inspection results) for one file, honoring permission and ignore rules.
 */
class DiagnosticsTool(
    private val project: Project,
    private val sessionId: String,
    private val hookManager: HookManager,
) : BaseTool<DiagnosticsTool.Args, DiagnosticsTool.Result>(
    workingDirectory = project.basePath ?: System.getProperty("user.dir"),
    argsSerializer = Args.serializer(),
    resultSerializer = Result.serializer(),
    name = "Diagnostics",
    description = """
        Reads the IDE's current diagnostics for a specific file.
        Use this tool when you need compiler or inspection diagnostics for one file.
        - The file_path parameter must be an absolute path.
        - filter='errors_only' returns only errors.
        - filter='all' returns errors, warnings, weak warnings, and info diagnostics.
        - Results reflect diagnostics currently available in the IDE for that file.
    """.trimIndent(),
    argsClass = Args::class,
    resultClass = Result::class,
    hookManager = hookManager,
    sessionId = sessionId,
) {
    // Tool-call arguments as produced by the model.
    @Serializable
    data class Args(
        @property:LLMDescription(
            "The absolute path to the file to inspect. Must be an absolute path."
        )
        @SerialName("file_path")
        val filePath: String,
        @property:LLMDescription(
            "Diagnostics filter: 'errors_only' for errors only, or 'all' for all diagnostics."
        )
        val filter: DiagnosticsFilter = DiagnosticsFilter.ERRORS_ONLY
    )

    // Structured result; error is non-null when the lookup failed or was denied.
    @Serializable
    data class Result(
        @SerialName("file_path")
        val filePath: String,
        val filter: DiagnosticsFilter,
        @SerialName("diagnostic_count")
        val diagnosticCount: Int = 0,
        val output: String = "",
        val error: String? = null
    )

    override suspend fun doExecute(args: Args): Result {
        val settingsService = project.service<ProxyAISettingsService>()
        // Explicit deny rules always win.
        val decision = settingsService.evaluateToolPermission(this, args.filePath)
        if (decision == ToolPermissionPolicy.Decision.DENY) {
            return Result(
                filePath = args.filePath,
                filter = args.filter,
                error = "Access denied by permissions.deny for Diagnostics"
            )
        }
        // When allow rules exist, anything not explicitly allowed is rejected.
        if (settingsService.hasAllowRulesForTool("Diagnostics")
            && decision != ToolPermissionPolicy.Decision.ALLOW
        ) {
            return Result(
                filePath = args.filePath,
                filter = args.filter,
                error = "Access denied by permissions.allow for Diagnostics"
            )
        }
        // Ignored paths are reported as missing — presumably to avoid revealing
        // that the file exists but is hidden from the agent.
        if (settingsService.isPathIgnored(args.filePath)) {
            return Result(
                filePath = args.filePath,
                filter = args.filter,
                error = "File not found: ${args.filePath}"
            )
        }
        val diagnosticsService = project.service<ProjectDiagnosticsService>()
        val virtualFile = diagnosticsService.findVirtualFile(args.filePath)
            ?: return Result(
                filePath = args.filePath,
                filter = args.filter,
                error = "File not found: ${args.filePath}"
            )
        val report = diagnosticsService.collect(virtualFile, args.filter)
        return Result(
            filePath = args.filePath,
            filter = args.filter,
            diagnosticCount = report.diagnosticCount,
            // An empty report body becomes the filter's "nothing found" message.
            output = report.content.ifBlank { args.filter.emptyMessage() },
            error = report.error
        )
    }

    // Result used when a hook or approval flow denies the call.
    override fun createDeniedResult(originalArgs: Args, deniedReason: String): Result {
        return Result(
            filePath = originalArgs.filePath,
            filter = originalArgs.filter,
            error = deniedReason
        )
    }

    // Renders the result for the LLM, truncating oversized output.
    override fun encodeResultToString(result: Result): String {
        if (result.error != null) {
            return "Failed to read diagnostics for '${result.filePath}': ${result.error}"
                .truncateToolResult()
        }
        return result.output.truncateToolResult()
    }
}

View file

@ -8,21 +8,19 @@ import ai.koog.agents.features.eventHandler.feature.handleEvents
import ai.koog.agents.features.tokenizer.feature.MessageTokenizer
import ai.koog.prompt.dsl.prompt
import ai.koog.prompt.tokenizer.Tokenizer
import com.intellij.openapi.components.service
import com.intellij.openapi.project.Project
import ee.carlrobert.codegpt.EncodingManager
import ee.carlrobert.codegpt.agent.AgentEvents
import ee.carlrobert.codegpt.agent.MessageWithContext
import ee.carlrobert.codegpt.agent.clients.shouldStream
import ee.carlrobert.codegpt.agent.clients.shouldStreamCustomOpenAI
import ee.carlrobert.codegpt.agent.strategy.CODE_AGENT_COMPRESSION
import ee.carlrobert.codegpt.agent.strategy.HistoryCompressionConfig
import ee.carlrobert.codegpt.agent.strategy.SingleRunStrategyProvider
import ee.carlrobert.codegpt.mcp.McpTool
import ee.carlrobert.codegpt.mcp.McpToolAliasResolver
import ee.carlrobert.codegpt.mcp.McpToolCallHandler
import ee.carlrobert.codegpt.settings.models.ModelSettings
import ee.carlrobert.codegpt.settings.service.ServiceType
import ee.carlrobert.codegpt.settings.service.custom.CustomServicesSettings
import ee.carlrobert.codegpt.util.ReasoningFrameTextAdapter
import kotlinx.coroutines.*
import java.util.*
@ -51,7 +49,7 @@ internal object AgentCompletionRunner : CompletionRunner {
val jobRef = AtomicReference<Job?>()
val messageBuilder = StringBuilder()
val project = request.callParameters.project!!
val stream = shouldStreamAgentToolLoop(project, request)
val stream = shouldStreamAgentToolLoop(request)
val toolCallHandler = project.let { McpToolCallHandler.getInstance(it) }
val toolRegistry = createChatToolRegistry(
callParameters = request.callParameters,
@ -176,21 +174,13 @@ internal object AgentCompletionRunner : CompletionRunner {
}
internal fun shouldStreamAgentToolLoop(
project: Project,
request: CompletionRunnerRequest.Chat,
): Boolean {
val provider = request.serviceType
return when (provider) {
ServiceType.CUSTOM_OPENAI -> {
val selectedServiceId = project.service<ModelSettings>()
.getStoredModelForFeature(request.callParameters.featureType)
project.service<CustomServicesSettings>().state.services
.firstOrNull { it.id == selectedServiceId }
?.chatCompletionSettings
?.shouldStream()
?: false
}
ServiceType.CUSTOM_OPENAI -> shouldStreamCustomOpenAI(
request.callParameters.featureType
)
ServiceType.GOOGLE -> false
else -> true
}

View file

@ -0,0 +1,231 @@
package ee.carlrobert.codegpt.diagnostics
import com.intellij.codeInsight.daemon.impl.DaemonCodeAnalyzerImpl
import com.intellij.codeInsight.daemon.impl.HighlightInfo
import com.intellij.lang.annotation.HighlightSeverity
import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.components.Service
import com.intellij.openapi.fileEditor.FileDocumentManager
import com.intellij.openapi.project.DumbService
import com.intellij.openapi.project.Project
import com.intellij.openapi.util.text.StringUtil
import com.intellij.openapi.vfs.LocalFileSystem
import com.intellij.openapi.vfs.VirtualFile
import com.intellij.psi.PsiDocumentManager
import com.intellij.psi.PsiManager
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
/**
 * Controls which diagnostic severities a diagnostics request reports.
 */
@Serializable
enum class DiagnosticsFilter(val displayName: String) {
    @SerialName("errors_only")
    ERRORS_ONLY("Errors only"),

    @SerialName("all")
    ALL("All");

    /** Whether a highlight of [severity] passes this filter. */
    fun includes(severity: HighlightSeverity): Boolean = when (this) {
        ERRORS_ONLY -> severity == HighlightSeverity.ERROR
        ALL -> severity in setOf(
            HighlightSeverity.ERROR,
            HighlightSeverity.WARNING,
            HighlightSeverity.WEAK_WARNING,
            HighlightSeverity.INFORMATION,
        )
    }

    /** The lowest severity the IDE should be asked to collect for this filter. */
    fun minimumSeverity(): HighlightSeverity = when (this) {
        ERRORS_ONLY -> HighlightSeverity.ERROR
        ALL -> HighlightSeverity.INFORMATION
    }

    /** Human-readable message for a report with zero matching diagnostics. */
    fun emptyMessage(): String = when (this) {
        ERRORS_ONLY -> "No errors found."
        ALL -> "No diagnostics found."
    }
}
/**
 * Outcome of one diagnostics collection pass over a single file.
 *
 * [error] is non-null when collection failed; otherwise [content] holds the
 * formatted report (empty when nothing matched the filter).
 */
data class DiagnosticsReport(
    val filePath: String,
    val filter: DiagnosticsFilter,
    // Pre-formatted, human-readable listing of the diagnostics found.
    val content: String = "",
    // Number of diagnostics that matched the filter (after de-duplication).
    val diagnosticCount: Int = 0,
    val error: String? = null
) {
    // True only for a successful pass that found at least one diagnostic.
    val hasDiagnostics: Boolean
        get() = error == null && diagnosticCount > 0
}
@Service(Service.Level.PROJECT)
class ProjectDiagnosticsService(
private val project: Project
) {
fun findVirtualFile(filePath: String): VirtualFile? {
val normalizedPath = filePath.replace('\\', '/')
val fileSystem = LocalFileSystem.getInstance()
return fileSystem.findFileByPath(normalizedPath)
?: fileSystem.refreshAndFindFileByPath(normalizedPath)
}
fun collect(
virtualFile: VirtualFile,
filter: DiagnosticsFilter = DiagnosticsFilter.ALL
): DiagnosticsReport {
return try {
var result = DiagnosticsReport(
filePath = virtualFile.path,
filter = filter
)
ApplicationManager.getApplication().invokeAndWait {
result = ApplicationManager.getApplication().runWriteAction<DiagnosticsReport> {
DumbService.getInstance(project).runReadActionInSmartMode<DiagnosticsReport> {
val document = FileDocumentManager.getInstance().getDocument(virtualFile)
?: return@runReadActionInSmartMode DiagnosticsReport(
filePath = virtualFile.path,
filter = filter,
error = "No document found for file."
)
PsiDocumentManager.getInstance(project).commitDocument(document)
val psiFile = PsiManager.getInstance(project).findFile(virtualFile)
?: return@runReadActionInSmartMode DiagnosticsReport(
filePath = virtualFile.path,
filter = filter,
error = "No PSI file found for: ${virtualFile.path}"
)
val rangeHighlights = DaemonCodeAnalyzerImpl.getHighlights(
document,
filter.minimumSeverity(),
project
)
val fileLevel = getFileLevelHighlights(psiFile)
val highlights = (rangeHighlights.asSequence() + fileLevel.asSequence())
.filter { filter.includes(it.severity) }
.mapNotNull { highlight ->
extractMessage(highlight)?.let { message ->
DiagnosticEntry(highlight, message)
}
}
.distinctBy { Triple(it.message, it.highlight.startOffset, it.highlight.severity) }
.sortedWith(
compareBy<DiagnosticEntry>(
{ severityOrder(it.highlight.severity) },
{ it.highlight.startOffset.coerceAtLeast(0) }
)
)
.toList()
if (highlights.isEmpty()) {
return@runReadActionInSmartMode DiagnosticsReport(
filePath = virtualFile.path,
filter = filter
)
}
val maxItems = 200
val overflow = (highlights.size - maxItems).coerceAtLeast(0)
val shown = highlights.take(maxItems)
val content = buildString {
append("File: ${virtualFile.name}\n")
append("Path: ${virtualFile.path}\n")
append("Filter: ${filter.displayName}\n\n")
shown.forEach { entry ->
val info = entry.highlight
val startOffset = info.startOffset.coerceIn(0, document.textLength)
val lineColText =
if (info.startOffset >= 0 && document.textLength > 0) {
val line = document.getLineNumber(startOffset) + 1
val col =
startOffset - document.getLineStartOffset(line - 1) + 1
"line $line, col $col"
} else {
"file-level"
}
append("- [${severityLabel(info.severity)}] $lineColText: ${entry.message}\n")
}
if (overflow > 0) {
append("... ($overflow more not shown)\n")
}
}
DiagnosticsReport(
filePath = virtualFile.path,
filter = filter,
content = content,
diagnosticCount = highlights.size
)
}
}
}
result
} catch (e: Exception) {
DiagnosticsReport(
filePath = virtualFile.path,
filter = filter,
error = "Error retrieving diagnostics: ${e.message}"
)
}
}
private fun getFileLevelHighlights(psiFile: com.intellij.psi.PsiFile): List<HighlightInfo> {
return try {
val method = DaemonCodeAnalyzerImpl::class.java.methods.firstOrNull {
it.name == "getFileLevelHighlights" && it.parameterCount == 2
}
if (method != null) {
@Suppress("UNCHECKED_CAST")
method.invoke(null, project, psiFile) as? List<HighlightInfo> ?: emptyList()
} else {
emptyList()
}
} catch (_: Throwable) {
emptyList()
}
}
/**
 * Extracts a plain-text message from a highlight, preferring [HighlightInfo.description]
 * over the tooltip. HTML markup is stripped; returns null when nothing readable remains.
 */
private fun extractMessage(info: HighlightInfo): String? {
    val raw = info.description ?: info.toolTip ?: ""
    val cleaned = StringUtil.removeHtmlTags(raw, false).trim()
    return if (cleaned.isBlank()) null else cleaned
}
/** Maps a platform severity to the label used in the diagnostics report. */
private fun severityLabel(severity: HighlightSeverity): String =
    when (severity) {
        HighlightSeverity.ERROR -> "ERROR"
        HighlightSeverity.WARNING -> "WARNING"
        HighlightSeverity.WEAK_WARNING -> "WEAK_WARNING"
        HighlightSeverity.INFORMATION -> "INFO"
        // Fall back to the severity's own string form for custom severities.
        else -> severity.toString()
    }
/** Sort rank for severities: errors first, unknown severities last. */
private fun severityOrder(severity: HighlightSeverity): Int =
    when (severity) {
        HighlightSeverity.ERROR -> 0
        HighlightSeverity.WARNING -> 1
        HighlightSeverity.WEAK_WARNING -> 2
        HighlightSeverity.INFORMATION -> 3
        else -> 4
    }
/**
 * Pairs a highlight with its already-extracted plain-text message so the
 * report builder can filter, deduplicate, and sort without re-parsing
 * [HighlightInfo] tooltips.
 */
private data class DiagnosticEntry(
    val highlight: HighlightInfo,
    val message: String
)
}

View file

@ -795,13 +795,18 @@ class InlineEditInlay(private var editor: Editor) : Disposable {
/**
 * Builds the diagnostics context for the inline edit prompt from every
 * selected diagnostics tag. Returns null when no diagnostics tag is selected
 * or none of the tag processors produced any output.
 */
private fun collectDiagnosticsInfo(): String? {
    val tags: Set<TagDetails> = tagManager.getTags()
    // filterIsInstance already narrows the type, replacing the redundant
    // `it.selected && it is DiagnosticsTagDetails` double check.
    val diagnosticsTags = tags
        .filter { it.selected }
        .filterIsInstance<DiagnosticsTagDetails>()
    if (diagnosticsTags.isEmpty()) {
        return null
    }
    val stringBuilder = StringBuilder()
    diagnosticsTags.forEach { diagnosticsTag ->
        val processor = TagProcessorFactory.getProcessor(project, diagnosticsTag)
        processor.process(Message("", ""), stringBuilder)
    }
    return stringBuilder.toString().takeIf { it.isNotBlank() }
}
}

View file

@ -327,6 +327,8 @@ class CustomServiceForm(
.setPreferredSize(Dimension(220, 0))
.setAddAction { handleAddAction() }
.setRemoveAction { handleRemoveAction() }
.setMoveUpAction { handleMoveAction(-1) }
.setMoveDownAction { handleMoveAction(1) }
.setRemoveActionUpdater {
formState.value.services.size > 1
}
@ -347,7 +349,6 @@ class CustomServiceForm(
handleDuplicateAction()
}
})
.disableUpDownActions()
private fun handleRemoveAction() {
val prevSelectedIndex = customProvidersJBList.selectedIndex
@ -373,6 +374,35 @@ class CustomServiceForm(
lastSelectedIndex = -1
}
// Moves the selected provider one position up (offset = -1) or down (offset = 1),
// keeping the moved entry selected after the list is rebuilt from state.
private fun handleMoveAction(offset: Int) {
    val selectedIndex = customProvidersJBList.selectedIndex
    if (selectedIndex == -1) {
        return
    }
    val targetIndex = selectedIndex + offset
    // No wrap-around: moving past either end of the list is a no-op.
    if (targetIndex !in formState.value.services.indices) {
        return
    }
    // Flush any unsaved edits of the currently displayed form into the state
    // before reordering, so they survive the list model rebuild.
    if (lastSelectedIndex != -1 && lastSelectedIndex < formState.value.services.size) {
        updateStateFromForm(lastSelectedIndex)
    }
    // Track the moved service by id so selection can be restored after the
    // state update re-renders the list.
    val movedId = formState.value.services[selectedIndex].id
    selectedServiceId = movedId
    pendingSelectedId = movedId
    formState.update { state ->
        val reordered = state.services.toMutableList()
        val moved = reordered.removeAt(selectedIndex)
        reordered.add(targetIndex, moved)
        state.copy(services = reordered.toImmutableList())
    }
    // Reset so a stale index is not used to write form data back on the next
    // selection change.
    lastSelectedIndex = -1
}
private fun handleDuplicateAction() {
if (lastSelectedIndex >= 0 && lastSelectedIndex < formState.value.services.size) {
updateStateFromForm(lastSelectedIndex)

View file

@ -8,7 +8,6 @@ import ee.carlrobert.codegpt.settings.service.FeatureType
import ee.carlrobert.codegpt.ui.textarea.header.tag.TagManager
import ee.carlrobert.codegpt.ui.textarea.lookup.LookupActionItem
import ee.carlrobert.codegpt.ui.textarea.lookup.LookupGroupItem
import ee.carlrobert.codegpt.ui.textarea.lookup.action.DiagnosticsActionItem
import ee.carlrobert.codegpt.ui.textarea.lookup.action.ImageActionItem
import ee.carlrobert.codegpt.ui.textarea.lookup.action.WebActionItem
import ee.carlrobert.codegpt.ui.textarea.lookup.action.files.IncludeOpenFilesActionItem
@ -44,7 +43,7 @@ class SearchManager(
FoldersGroupItem(project, tagManager),
if (GitFeatureAvailability.isAvailable) GitGroupItem(project) else null,
HistoryGroupItem(),
DiagnosticsActionItem(tagManager)
DiagnosticsGroupItem(tagManager)
).filter { it.enabled }
private fun getAgentGroups() = listOfNotNull(
@ -52,6 +51,7 @@ class SearchManager(
FoldersGroupItem(project, tagManager),
if (GitFeatureAvailability.isAvailable) GitGroupItem(project) else null,
MCPGroupItem(tagManager, FeatureType.AGENT),
DiagnosticsGroupItem(tagManager),
ImageActionItem(project, tagManager)
).filter { it.enabled }
@ -62,7 +62,7 @@ class SearchManager(
HistoryGroupItem(),
PersonasGroupItem(tagManager),
MCPGroupItem(tagManager, featureType ?: FeatureType.CHAT),
DiagnosticsActionItem(tagManager),
DiagnosticsGroupItem(tagManager),
WebActionItem(tagManager),
ImageActionItem(project, tagManager)
).filter { it.enabled }

View file

@ -1,24 +1,16 @@
package ee.carlrobert.codegpt.ui.textarea
import com.intellij.codeInsight.daemon.impl.DaemonCodeAnalyzerImpl
import com.intellij.codeInsight.daemon.impl.HighlightInfo
import com.intellij.lang.annotation.HighlightSeverity
import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.application.runReadAction
import com.intellij.openapi.components.service
import com.intellij.openapi.fileEditor.FileDocumentManager
import com.intellij.openapi.progress.ProgressManager
import com.intellij.openapi.project.DumbService
import com.intellij.openapi.project.Project
import com.intellij.openapi.util.text.StringUtil
import com.intellij.openapi.vfs.VirtualFile
import com.intellij.psi.PsiDocumentManager
import com.intellij.psi.PsiManager
import ee.carlrobert.codegpt.EncodingManager
import ee.carlrobert.codegpt.completions.CompletionRequestUtil
import ee.carlrobert.codegpt.conversations.Conversation
import ee.carlrobert.codegpt.conversations.ConversationsState
import ee.carlrobert.codegpt.conversations.message.Message
import ee.carlrobert.codegpt.diagnostics.ProjectDiagnosticsService
import ee.carlrobert.codegpt.settings.ProxyAISettingsService
import ee.carlrobert.codegpt.ui.textarea.header.tag.*
import ee.carlrobert.codegpt.ui.textarea.lookup.action.HistoryActionItem
@ -270,118 +262,17 @@ class DiagnosticsTagProcessor(
private val project: Project,
private val tagDetails: DiagnosticsTagDetails,
) : TagProcessor {
private val diagnosticsService = project.service<ProjectDiagnosticsService>()

/**
 * Appends the diagnostics report for the tagged file — respecting the tag's
 * severity filter — to the prompt. Emits nothing when the report is empty and
 * no collection error occurred. The legacy unfiltered "Current File Problems"
 * section (pre-merge residue duplicating this output) is removed.
 */
override fun process(message: Message, promptBuilder: StringBuilder) {
    val diagnostics = diagnosticsService.collect(tagDetails.virtualFile, tagDetails.filter)
    if (diagnostics.content.isBlank() && diagnostics.error == null) {
        return
    }
    promptBuilder
        .append("\n## ${tagDetails.virtualFile.name} Problems (${tagDetails.filter.displayName})\n")
        .append(diagnostics.error ?: diagnostics.content)
        .append("\n")
}
/**
 * Collects the current daemon highlights for [virtualFile] and renders them as
 * a plain-text report (file header, then one "- [SEVERITY] line L, col C: msg"
 * entry per highlight, capped at 200 items).
 *
 * Returns an explanatory string (not an exception) on every failure path.
 *
 * NOTE(review): this wraps a read-only highlight query in invokeAndWait +
 * runWriteAction — presumably only to allow commitDocument; verify a plain
 * read action would not suffice.
 */
private fun getDiagnosticsString(project: Project, virtualFile: VirtualFile): String {
    return try {
        var result = ""
        ApplicationManager.getApplication().invokeAndWait {
            result = ApplicationManager.getApplication().runWriteAction<String> {
                // Smart mode: highlight data is incomplete while indexing.
                DumbService.getInstance(project).runReadActionInSmartMode<String> {
                    val document = FileDocumentManager.getInstance().getDocument(virtualFile)
                        ?: return@runReadActionInSmartMode "No document found for file"
                    // Commit pending document changes so PSI and highlights match the text.
                    PsiDocumentManager.getInstance(project).commitDocument(document)
                    val psiManager = PsiManager.getInstance(project)
                    val psiFile = psiManager.findFile(virtualFile)
                        ?: return@runReadActionInSmartMode "No PSI file found for: ${virtualFile.path}"
                    // Range-anchored highlights at WEAK_WARNING severity and above.
                    val rangeHighlights =
                        DaemonCodeAnalyzerImpl.getHighlights(
                            document,
                            HighlightSeverity.WEAK_WARNING,
                            project
                        )
                    // TODO: Find a better solution
                    // Reflection: getFileLevelHighlights' signature varies across
                    // platform versions; failures degrade to an empty list.
                    val fileLevel: List<HighlightInfo> = try {
                        val method = DaemonCodeAnalyzerImpl::class.java.methods.firstOrNull {
                            it.name == "getFileLevelHighlights" && it.parameterCount == 2
                        }
                        if (method != null) {
                            @Suppress("UNCHECKED_CAST")
                            method.invoke(null, project, psiFile) as? List<HighlightInfo>
                                ?: emptyList()
                        } else {
                            emptyList()
                        }
                    } catch (_: Throwable) {
                        emptyList()
                    }
                    // Merge, dedupe (same message/offset/severity), and order by
                    // severity then position.
                    val highlights = (rangeHighlights.asSequence() + fileLevel.asSequence())
                        .distinctBy { Triple(it.description, it.startOffset, it.severity) }
                        .sortedWith(
                            compareBy<HighlightInfo>(
                                { severityOrder(it.severity) },
                                { it.startOffset.coerceAtLeast(0) }
                            )
                        )
                        .toList()
                    if (highlights.isEmpty()) {
                        return@runReadActionInSmartMode ""
                    }
                    // Cap output so huge files cannot flood the prompt.
                    val maxItems = 200
                    val overflow = (highlights.size - maxItems).coerceAtLeast(0)
                    val shown = highlights.take(maxItems)
                    buildString {
                        append("File: ${virtualFile.name}\n")
                        append("Path: ${virtualFile.path}\n\n")
                        shown.forEach { info ->
                            val startOffset = info.startOffset.coerceIn(0, document.textLength)
                            // Negative offsets mark file-level (non-positional) problems.
                            val lineColText =
                                if (info.startOffset >= 0 && document.textLength > 0) {
                                    val line = document.getLineNumber(startOffset) + 1
                                    val col =
                                        startOffset - document.getLineStartOffset(line - 1) + 1
                                    "line $line, col $col"
                                } else {
                                    "file-level"
                                }
                            val rawMessage = info.description ?: info.toolTip ?: ""
                            val message = StringUtil.removeHtmlTags(rawMessage, false).trim()
                            val severityLabel = when (info.severity) {
                                HighlightSeverity.ERROR -> "ERROR"
                                HighlightSeverity.WARNING -> "WARNING"
                                HighlightSeverity.WEAK_WARNING -> "WEAK_WARNING"
                                HighlightSeverity.INFORMATION -> "INFO"
                                else -> info.severity.toString()
                            }
                            append("- [$severityLabel] $lineColText: $message\n")
                        }
                        if (overflow > 0) {
                            append("... ($overflow more not shown)\n")
                        }
                    }
                }
            }
        }
        result
    } catch (e: Exception) {
        "Error retrieving diagnostics: ${e.message}"
    }
}
/** Sort rank for severities: errors first, unrecognized severities last. */
private fun severityOrder(severity: HighlightSeverity): Int =
    when (severity) {
        HighlightSeverity.ERROR -> 0
        HighlightSeverity.WARNING -> 1
        HighlightSeverity.WEAK_WARNING -> 2
        HighlightSeverity.INFORMATION -> 3
        else -> 4
    }
}

View file

@ -5,6 +5,7 @@ import com.intellij.openapi.editor.SelectionModel
import com.intellij.openapi.vfs.VirtualFile
import com.intellij.ui.JBColor
import ee.carlrobert.codegpt.Icons
import ee.carlrobert.codegpt.diagnostics.DiagnosticsFilter
import ee.carlrobert.codegpt.mcp.ConnectionStatus
import ee.carlrobert.codegpt.mcp.McpResource
import ee.carlrobert.codegpt.mcp.McpTool
@ -304,7 +305,9 @@ class CodeAnalyzeTagDetails : TagDetails("Code Analyze", AllIcons.Actions.Depend
override fun getTooltipText(): String? = null
}
/**
 * Tag representing the diagnostics of [virtualFile], optionally narrowed by
 * [filter] (defaults to all severities). Label and tooltip both surface the
 * active filter so differently filtered tags are distinguishable.
 *
 * The unterminated pre-merge declaration that preceded this one is removed.
 */
data class DiagnosticsTagDetails(
    val virtualFile: VirtualFile,
    val filter: DiagnosticsFilter = DiagnosticsFilter.ALL
) : TagDetails("${virtualFile.name} Problems (${filter.displayName})", AllIcons.General.InspectionsEye) {
    override fun getTooltipText(): String = "${virtualFile.path} (${filter.displayName})"
}

View file

@ -1,41 +0,0 @@
package ee.carlrobert.codegpt.ui.textarea.lookup.action
import com.intellij.icons.AllIcons
import com.intellij.openapi.project.Project
import com.intellij.openapi.vfs.VirtualFile
import ee.carlrobert.codegpt.ui.textarea.UserInputPanel
import ee.carlrobert.codegpt.ui.textarea.header.tag.DiagnosticsTagDetails
import ee.carlrobert.codegpt.ui.textarea.header.tag.EditorTagDetails
import ee.carlrobert.codegpt.ui.textarea.header.tag.FileTagDetails
import ee.carlrobert.codegpt.ui.textarea.header.tag.TagManager
import ee.carlrobert.codegpt.util.EditorUtil
/**
 * Lookup action that attaches a diagnostics tag for the most relevant context
 * file. Enabled only while no diagnostics tag exists yet and at least one
 * file/editor tag is present.
 */
class DiagnosticsActionItem(
    private val tagManager: TagManager
) : AbstractLookupActionItem() {

    override val displayName: String = "Diagnostics"
    override val icon = AllIcons.General.InspectionsEye

    override val enabled: Boolean
        get() {
            // Snapshot the tags once instead of querying the manager twice.
            val tags = tagManager.getTags()
            return tags.none { it is DiagnosticsTagDetails } &&
                tags.any { it is FileTagDetails || it is EditorTagDetails }
        }

    override fun execute(project: Project, userInputPanel: UserInputPanel) {
        val virtualFile = findVirtualFile(project)
        virtualFile?.let { file ->
            userInputPanel.addTag(DiagnosticsTagDetails(file))
        }
    }

    /**
     * Picks the file to diagnose: the first file/editor tag if present,
     * otherwise the file open in the currently selected editor.
     */
    private fun findVirtualFile(project: Project): VirtualFile? {
        val existingFile = tagManager.getTags()
            .firstNotNullOfOrNull { tag ->
                when (tag) {
                    is FileTagDetails -> tag.virtualFile
                    is EditorTagDetails -> tag.virtualFile
                    else -> null
                }
            }
        return existingFile ?: EditorUtil.getSelectedEditor(project)?.virtualFile
    }
}

View file

@ -0,0 +1,24 @@
package ee.carlrobert.codegpt.ui.textarea.lookup.action
import com.intellij.openapi.vfs.VirtualFile
import ee.carlrobert.codegpt.ui.textarea.header.tag.EditorSelectionTagDetails
import ee.carlrobert.codegpt.ui.textarea.header.tag.EditorTagDetails
import ee.carlrobert.codegpt.ui.textarea.header.tag.FileTagDetails
import ee.carlrobert.codegpt.ui.textarea.header.tag.SelectionTagDetails
import ee.carlrobert.codegpt.ui.textarea.header.tag.TagDetails
/**
 * Collects the distinct virtual files referenced by the selected file/editor/
 * selection tags, preserving first-seen order and deduplicating by path.
 */
internal fun selectedContextFiles(tags: Collection<TagDetails>): List<VirtualFile> {
    val byPath = LinkedHashMap<String, VirtualFile>()
    for (tag in tags) {
        if (!tag.selected) continue
        val file = when (tag) {
            is FileTagDetails -> tag.virtualFile
            is EditorTagDetails -> tag.virtualFile
            is SelectionTagDetails -> tag.virtualFile
            is EditorSelectionTagDetails -> tag.virtualFile
            else -> null
        } ?: continue
        // putIfAbsent keeps the first tag that referenced a given path.
        byPath.putIfAbsent(file.path, file)
    }
    return byPath.values.toList()
}

View file

@ -0,0 +1,57 @@
package ee.carlrobert.codegpt.ui.textarea.lookup.action
import com.intellij.icons.AllIcons
import com.intellij.notification.NotificationType
import com.intellij.openapi.components.service
import com.intellij.openapi.project.Project
import ee.carlrobert.codegpt.diagnostics.DiagnosticsFilter
import ee.carlrobert.codegpt.diagnostics.ProjectDiagnosticsService
import ee.carlrobert.codegpt.ui.OverlayUtil
import ee.carlrobert.codegpt.ui.textarea.UserInputPanel
import ee.carlrobert.codegpt.ui.textarea.header.tag.DiagnosticsTagDetails
import ee.carlrobert.codegpt.ui.textarea.header.tag.TagManager
/**
 * Lookup action that tags every selected context file with a diagnostics tag
 * for [filter]. Files without matching diagnostics are skipped; when no file
 * matches at all, an informational notification is shown instead.
 */
class DiagnosticsFilterActionItem(
    private val tagManager: TagManager,
    private val filter: DiagnosticsFilter
) : AbstractLookupActionItem() {

    override val displayName: String = filter.displayName
    override val icon = AllIcons.General.InspectionsEye

    override fun execute(project: Project, userInputPanel: UserInputPanel) {
        val diagnosticsService = project.service<ProjectDiagnosticsService>()
        var matched = false
        for (virtualFile in selectedContextFiles(userInputPanel.getSelectedTags())) {
            val diagnostics = diagnosticsService.collect(virtualFile, filter)
            if (!diagnostics.hasDiagnostics) {
                continue
            }
            matched = true
            val newTag = DiagnosticsTagDetails(virtualFile, filter)
            val existing = tagManager.getTags()
                .filterIsInstance<DiagnosticsTagDetails>()
                .firstOrNull { it.virtualFile == virtualFile }
            if (existing == null) {
                userInputPanel.addTag(newTag)
            } else if (existing != newTag) {
                // Same file, different filter: swap the tag in place.
                tagManager.updateTag(existing, newTag)
            }
        }
        if (!matched) {
            OverlayUtil.showNotification(
                filter.emptyMessage().removeSuffix(".") + " in selected context files.",
                NotificationType.INFORMATION
            )
        }
    }
}

View file

@ -0,0 +1,27 @@
package ee.carlrobert.codegpt.ui.textarea.lookup.group
import com.intellij.icons.AllIcons
import ee.carlrobert.codegpt.diagnostics.DiagnosticsFilter
import ee.carlrobert.codegpt.ui.textarea.header.tag.TagManager
import ee.carlrobert.codegpt.ui.textarea.lookup.LookupActionItem
import ee.carlrobert.codegpt.ui.textarea.lookup.action.DiagnosticsFilterActionItem
import ee.carlrobert.codegpt.ui.textarea.lookup.action.selectedContextFiles
/**
 * "Diagnostics" lookup group, enabled only while at least one context file is
 * selected. Offers one action per supported severity filter, narrowed by the
 * search text.
 */
class DiagnosticsGroupItem(
    private val tagManager: TagManager
) : AbstractLookupGroupItem() {

    override val displayName: String = "Diagnostics"
    override val icon = AllIcons.General.InspectionsEye

    override val enabled: Boolean
        get() = selectedContextFiles(tagManager.getTags()).isNotEmpty()

    override suspend fun getLookupItems(searchText: String): List<LookupActionItem> {
        val filters = listOf(DiagnosticsFilter.ERRORS_ONLY, DiagnosticsFilter.ALL)
        return filters
            .map { DiagnosticsFilterActionItem(tagManager, it) }
            .filter { searchText.isEmpty() || it.displayName.contains(searchText, ignoreCase = true) }
    }
}

View file

@ -4,6 +4,7 @@ import ai.koog.agents.snapshot.providers.file.JVMFilePersistenceStorageProvider
import com.intellij.openapi.components.service
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.*
import ee.carlrobert.codegpt.credentials.CredentialsStore.setCredential
import ee.carlrobert.codegpt.agent.clients.shouldStreamCustomOpenAI
import ee.carlrobert.codegpt.settings.models.ModelCatalog
import ee.carlrobert.codegpt.settings.models.ModelSettings
import ee.carlrobert.codegpt.settings.service.FeatureType
@ -174,6 +175,144 @@ class AgentProviderIntegrationTest : IntegrationTest() {
assertThat(result.events.text.toString()).isEqualTo("Hello from Custom OpenAI")
}
// Verifies the agent streams chat completions when the stored model selection
// references the custom service by model id and the service enables streaming.
fun testCustomOpenAIAgentStreamsWhenStoredSelectionUsesModelId() {
    val customService = configureCustomOpenAIService(stream = true)
    service<ModelSettings>().setModel(
        FeatureType.AGENT,
        "custom-agent-model",
        ServiceType.CUSTOM_OPENAI
    )
    assertThat(shouldStreamCustomOpenAI(FeatureType.AGENT)).isTrue()
    expectCustomOpenAI(StreamHttpExchange { request ->
        // The outgoing request must target the chat endpoint with stream=true.
        assertThat(request.uri.path).isEqualTo("/v1/chat/completions")
        assertThat(request.method).isEqualTo("POST")
        assertThat(request.body["stream"]).isEqualTo(true)
        assertThat(extractPromptText(request)).contains("Say hello from streamed Custom OpenAI")
        chatCompletionChunks("custom-agent-model", "Hello from streamed Custom OpenAI")
    })
    val result = runAgent(ServiceType.CUSTOM_OPENAI, "Say hello from streamed Custom OpenAI")
    assertThat(customService.id).isNotBlank()
    assertThat(result.output).isEqualTo("Hello from streamed Custom OpenAI")
    assertThat(result.events.text.toString()).isEqualTo("Hello from streamed Custom OpenAI")
}
// Same as the chat-completions streaming test, but against a custom service
// configured for the Responses API ("/v1/responses" + $OPENAI_MESSAGES body).
fun testCustomOpenAIResponsesAgentStreamsWhenStoredSelectionUsesModelId() {
    val customService = configureCustomOpenAIService(
        path = "/v1/responses",
        model = "custom-responses-model",
        stream = true,
        useResponsesApiBody = true
    )
    service<ModelSettings>().setModel(
        FeatureType.AGENT,
        "custom-responses-model",
        ServiceType.CUSTOM_OPENAI
    )
    assertThat(shouldStreamCustomOpenAI(FeatureType.AGENT)).isTrue()
    expectCustomOpenAI(StreamHttpExchange { request ->
        assertThat(request.uri.path).isEqualTo("/v1/responses")
        assertThat(request.method).isEqualTo("POST")
        assertThat(request.body["stream"]).isEqualTo(true)
        assertThat(extractPromptText(request)).contains("Say hello from Custom OpenAI Responses")
        openAiResponsesChunks(
            model = "custom-responses-model",
            text = "Hello from Custom OpenAI Responses"
        )
    })
    val result = runAgent(ServiceType.CUSTOM_OPENAI, "Say hello from Custom OpenAI Responses")
    assertThat(customService.id).isNotBlank()
    assertThat(result.output).isEqualTo("Hello from Custom OpenAI Responses")
    assertThat(result.events.text.toString()).isEqualTo("Hello from Custom OpenAI Responses")
}
// Non-streaming Responses-API tool loop: the first exchange returns a
// function_call; the second request must replay that call and its output as
// Responses "input" items before the final text answer is produced.
fun testCustomOpenAIResponsesAgentSerializesToolHistoryAsResponsesInput() {
    val fixture = createReadFixture("Custom Responses fixture")
    configureCustomOpenAIService(
        path = "/v1/responses",
        model = "custom-responses-model",
        stream = false,
        useResponsesApiBody = true
    )
    expectCustomOpenAI(BasicHttpExchange { request ->
        assertThat(request.uri.path).isEqualTo("/v1/responses")
        assertThat(request.method).isEqualTo("POST")
        assertThat(extractPromptText(request)).contains("Read the fixture and repeat its contents")
        ResponseEntity(
            openAiResponsesToolCallResponse(
                model = "custom-responses-model",
                toolName = "Read",
                callId = "call_custom_responses_read",
                arguments = """{"file_path":"${fixture.path}"}"""
            )
        )
    })
    expectCustomOpenAI(BasicHttpExchange { request ->
        assertThat(request.uri.path).isEqualTo("/v1/responses")
        assertThat(request.method).isEqualTo("POST")
        // The second request carries the prior tool call + output as input items.
        assertThatResponsesToolHistory(
            request = request,
            toolName = "Read",
            callId = "call_custom_responses_read",
            fixture = fixture
        )
        ResponseEntity(
            openAiResponsesResponse(
                model = "custom-responses-model",
                text = "Custom Responses read: ${fixture.contents}"
            )
        )
    })
    val result = runAgent(ServiceType.CUSTOM_OPENAI, "Read the fixture and repeat its contents")
    assertThat(result.output).isEqualTo("Custom Responses read: ${fixture.contents}")
    assertThat(result.events.text.toString()).isEqualTo("Custom Responses read: ${fixture.contents}")
}
// Streaming variant of the Responses-API tool loop: the tool call arrives as
// SSE chunks, and the follow-up streamed request must still replay the tool
// history as Responses input items.
fun testCustomOpenAIResponsesStreamingAgentCompletesToolLoop() {
    val fixture = createReadFixture("Custom Responses streaming fixture")
    configureCustomOpenAIService(
        path = "/v1/responses",
        model = "custom-responses-model",
        stream = true,
        useResponsesApiBody = true
    )
    expectCustomOpenAI(StreamHttpExchange { request ->
        assertThat(request.uri.path).isEqualTo("/v1/responses")
        assertThat(request.method).isEqualTo("POST")
        assertThat(extractPromptText(request)).contains("Read the fixture and repeat its contents")
        openAiResponsesToolCallChunks(
            model = "custom-responses-model",
            toolName = "Read",
            callId = "call_custom_responses_stream_read",
            arguments = """{"file_path":"${fixture.path}"}"""
        )
    })
    expectCustomOpenAI(StreamHttpExchange { request ->
        assertThat(request.uri.path).isEqualTo("/v1/responses")
        assertThat(request.method).isEqualTo("POST")
        assertThatResponsesToolHistory(
            request = request,
            toolName = "Read",
            callId = "call_custom_responses_stream_read",
            fixture = fixture
        )
        openAiResponsesChunks(
            model = "custom-responses-model",
            text = "Custom Responses streaming read: ${fixture.contents}"
        )
    })
    val result = runAgent(ServiceType.CUSTOM_OPENAI, "Read the fixture and repeat its contents")
    assertThat(result.output).isEqualTo("Custom Responses streaming read: ${fixture.contents}")
    assertThat(result.events.text.toString()).isEqualTo("Custom Responses streaming read: ${fixture.contents}")
}
private fun runAgent(
provider: ServiceType,
userMessage: String
@ -246,9 +385,10 @@ class AgentProviderIntegrationTest : IntegrationTest() {
)
}
private fun openAiResponsesChunks(): List<String> {
val model = "gpt-4.1-mini"
val text = "Hello from OpenAI"
private fun openAiResponsesChunks(
model: String = "gpt-4.1-mini",
text: String = "Hello from OpenAI"
): List<String> {
val chunks = text.chunked(4)
return chunks.mapIndexed { index, chunk ->
jsonMapResponse(
@ -313,6 +453,154 @@ class AgentProviderIntegrationTest : IntegrationTest() {
)
}
// Builds the SSE chunk sequence a Responses-API server emits for a streamed
// function call: output_item.added, a function_call_arguments.delta carrying
// the full arguments, then response.completed with the final function_call item.
private fun openAiResponsesToolCallChunks(
    model: String,
    toolName: String,
    callId: String,
    arguments: String
): List<String> {
    return listOf(
        jsonMapResponse(
            e("type", "response.output_item.added"),
            e(
                "item",
                jsonMap(
                    e("type", "function_call"),
                    e("id", "fc_1"),
                    e("call_id", callId),
                    e("name", toolName),
                    // Arguments start empty and arrive via the delta chunk below.
                    e("arguments", "")
                )
            ),
            e("output_index", 0),
            e("sequence_number", 1)
        ),
        jsonMapResponse(
            e("type", "response.function_call_arguments.delta"),
            e("item_id", "fc_1"),
            e("output_index", 0),
            e("delta", arguments),
            e("call_id", callId),
            e("sequence_number", 2)
        ),
        jsonMapResponse(
            e("type", "response.completed"),
            e(
                "response",
                jsonMap(
                    e("id", "resp-tool-call"),
                    e("object", "response"),
                    e("created_at", 1),
                    e("model", model),
                    e(
                        "output",
                        jsonArray(
                            jsonMap(
                                e("type", "function_call"),
                                e("id", "fc_1"),
                                e("call_id", callId),
                                e("name", toolName),
                                e("arguments", arguments),
                                e("status", "completed")
                            )
                        )
                    ),
                    e("parallel_tool_calls", true),
                    e("status", "completed"),
                    e("text", jsonMap())
                )
            ),
            e("sequence_number", 3)
        )
    )
}
// Builds a complete (non-streaming) Responses-API response whose single
// output item is a finished function_call, including a minimal usage block.
private fun openAiResponsesToolCallResponse(
    model: String,
    toolName: String,
    callId: String,
    arguments: String
): String {
    return jsonMapResponse(
        e("id", "resp-tool-call"),
        e("object", "response"),
        e("created_at", 1),
        e("model", model),
        e(
            "output",
            jsonArray(
                jsonMap(
                    e("type", "function_call"),
                    e("id", "fc_1"),
                    e("call_id", callId),
                    e("name", toolName),
                    e("arguments", arguments),
                    e("status", "completed")
                )
            )
        ),
        e("parallel_tool_calls", true),
        e("status", "completed"),
        e("text", jsonMap()),
        e(
            "usage",
            jsonMap(
                e("input_tokens", 1),
                e("input_tokens_details", jsonMap("cached_tokens", 0)),
                e("output_tokens", 1),
                e("output_tokens_details", jsonMap("reasoning_tokens", 0)),
                e("total_tokens", 2)
            )
        )
    )
}
// Builds a complete (non-streaming) Responses-API response containing a single
// assistant message with one output_text content item.
private fun openAiResponsesResponse(
    model: String,
    text: String
): String {
    return jsonMapResponse(
        e("id", "resp-openai-test"),
        e("object", "response"),
        e("created_at", 1),
        e("model", model),
        e(
            "output",
            jsonArray(
                jsonMap(
                    e("type", "message"),
                    e("id", "msg_1"),
                    e("role", "assistant"),
                    e("status", "completed"),
                    e(
                        "content",
                        jsonArray(
                            jsonMap(
                                e("type", "output_text"),
                                e("text", text),
                                e("annotations", jsonArray())
                            )
                        )
                    )
                )
            )
        ),
        e("parallel_tool_calls", true),
        e("status", "completed"),
        e("text", jsonMap()),
        e(
            "usage",
            jsonMap(
                e("input_tokens", 1),
                e("input_tokens_details", jsonMap("cached_tokens", 0)),
                e("output_tokens", 1),
                e("output_tokens_details", jsonMap("reasoning_tokens", 0)),
                e("total_tokens", 2)
            )
        )
    )
}
private fun anthropicTextChunks(text: String): List<String> {
return listOf(
jsonMapResponse(
@ -573,15 +861,24 @@ class AgentProviderIntegrationTest : IntegrationTest() {
)
}
private fun configureCustomOpenAIService(): CustomServiceSettingsState {
private fun configureCustomOpenAIService(
path: String = "/v1/chat/completions",
model: String = "custom-agent-model",
stream: Boolean = false,
useResponsesApiBody: Boolean = false
): CustomServiceSettingsState {
val settings = service<CustomServicesSettings>()
val serviceState = CustomServiceSettingsState().apply {
name = "Agent Test Custom Service"
chatCompletionSettings.url =
System.getProperty("customOpenAI.baseUrl") + "/v1/chat/completions"
System.getProperty("customOpenAI.baseUrl") + path
chatCompletionSettings.headers.clear()
chatCompletionSettings.body.clear()
chatCompletionSettings.body["model"] = "custom-agent-model"
chatCompletionSettings.body["model"] = model
chatCompletionSettings.body["stream"] = stream
if (useResponsesApiBody) {
chatCompletionSettings.body["input"] = "\$OPENAI_MESSAGES"
}
}
settings.state.services.clear()
settings.state.services.add(serviceState)
@ -616,6 +913,28 @@ class AgentProviderIntegrationTest : IntegrationTest() {
}
}
/**
 * Asserts that a Responses-API request body replays the prior tool loop:
 * the original function_call item (matching [toolName]/[callId]/arguments)
 * and the function_call_output whose output contains the fixture contents.
 * Failure messages include the actual items for easier debugging.
 */
private fun assertThatResponsesToolHistory(
    request: RequestEntity,
    toolName: String,
    callId: String,
    fixture: ReadFixture
) {
    val input = request.body["input"] as? List<*> ?: error("Expected responses input list")
    val items = input.mapNotNull { it as? Map<*, *> }
    assertThat(items.any { item ->
        item["type"] == "function_call" &&
            item["name"] == toolName &&
            item["call_id"] == callId &&
            item["arguments"] == """{"file_path":"${fixture.path}"}"""
    })
        .withFailMessage("Missing function_call item for %s (%s) in: %s", toolName, callId, items)
        .isTrue()
    assertThat(items.any { item ->
        item["type"] == "function_call_output" &&
            item["call_id"] == callId &&
            (item["output"] as? String).orEmpty().contains(fixture.contents)
    })
        .withFailMessage("Missing function_call_output for call_id %s in: %s", callId, items)
        .isTrue()
}
private fun extractGooglePromptText(request: RequestEntity): String {
val contents = request.body["contents"] as? List<*> ?: return ""
return contents.joinToString("\n") { content ->

View file

@ -0,0 +1,26 @@
package ee.carlrobert.codegpt.agent
import ee.carlrobert.codegpt.agent.tools.DiagnosticsTool
import ee.carlrobert.codegpt.settings.hooks.HookManager
import kotlinx.coroutines.runBlocking
import org.assertj.core.api.Assertions.assertThat
import testsupport.IntegrationTest
import java.io.File
/**
 * Smoke tests for the agent DiagnosticsTool: registration in the tool specs
 * and the error result returned for a nonexistent file path.
 */
class DiagnosticsToolTest : IntegrationTest() {

    fun `test diagnostics tool should be registered in tool specs`() {
        assertThat(ToolSpecs.find("Diagnostics")).isNotNull()
    }

    fun `test diagnostics tool should return missing file error`() {
        val missingPath = File(project.basePath, "does-not-exist.txt").absolutePath
        val tool = DiagnosticsTool(project, "test-session-id", HookManager(project))
        // Suspend entry point driven from a plain JUnit3-style test method.
        val result = runBlocking {
            tool.execute(DiagnosticsTool.Args(missingPath))
        }
        assertThat(result.error).isEqualTo("File not found: $missingPath")
    }
}

View file

@ -0,0 +1,219 @@
package ee.carlrobert.codegpt.agent.clients
import ai.koog.prompt.executor.clients.openai.base.models.Content
import ai.koog.prompt.executor.clients.openai.base.models.OpenAIFunction
import ai.koog.prompt.executor.clients.openai.base.models.OpenAIMessage
import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolCall
import ai.koog.prompt.executor.clients.openai.base.models.OpenAITool
import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolChoice
import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolFunction
import ai.koog.prompt.llm.LLModel
import ai.koog.prompt.params.LLMParams
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceChatCompletionSettingsState
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.buildJsonArray
import kotlinx.serialization.json.buildJsonObject
import kotlinx.serialization.json.boolean
import kotlinx.serialization.json.jsonArray
import kotlinx.serialization.json.jsonObject
import kotlinx.serialization.json.jsonPrimitive
import kotlinx.serialization.json.put
import org.assertj.core.api.Assertions.assertThat
import org.junit.Test
class CustomOpenAIResponsesApiSerializationTest {
private val json = Json {
ignoreUnknownKeys = true
encodeDefaults = true
explicitNulls = false
}
/**
 * Builds settings that route a custom OpenAI service to the Responses API
 * (streaming enabled, $OPENAI_MESSAGES input placeholder).
 */
private fun responsesApiState(): CustomServiceChatCompletionSettingsState =
    CustomServiceChatCompletionSettingsState().apply {
        url = "https://example.com/v1/responses"
        body.clear()
        body["stream"] = true
        body["input"] = "\$OPENAI_MESSAGES"
    }

/**
 * Invokes the client's private serializeProviderChatRequest via reflection
 * (it is not part of the public API) and returns the JSON payload produced.
 * Shared by all serialization tests to avoid triplicating the scaffolding.
 */
private fun serializeRequest(
    messages: List<OpenAIMessage>,
    tools: List<OpenAITool>,
    toolChoice: OpenAIToolChoice?
): String {
    val client = CustomOpenAILLMClient.fromSettingsState("test-key", responsesApiState())
    val serializeMethod = client.javaClass.getDeclaredMethod(
        "serializeProviderChatRequest",
        List::class.java,
        LLModel::class.java,
        List::class.java,
        OpenAIToolChoice::class.java,
        LLMParams::class.java,
        Boolean::class.javaPrimitiveType
    )
    serializeMethod.isAccessible = true
    return serializeMethod.invoke(
        client,
        messages,
        LLModel(
            id = "gpt-test",
            provider = CustomOpenAILLMClient.CustomOpenAI,
            capabilities = emptyList(),
            contextLength = 128_000,
            maxOutputTokens = 4_096
        ),
        tools,
        toolChoice,
        CustomOpenAIParams(),
        true
    ) as String
}

@Test
fun `responses api request should encode tools with top level name`() {
    val payload = serializeRequest(
        messages = listOf(OpenAIMessage.User(content = Content.Text("hello"))),
        tools = listOf(
            OpenAITool(
                OpenAIToolFunction(
                    name = "Diagnostics",
                    description = "Read IDE diagnostics",
                    parameters = buildJsonObject {
                        put("type", "object")
                        put("properties", buildJsonObject {})
                        put("required", buildJsonArray {})
                    },
                    strict = true
                )
            )
        ),
        toolChoice = OpenAIToolChoice.Function(OpenAIToolChoice.FunctionName("Diagnostics"))
    )
    val request = json.parseToJsonElement(payload).jsonObject
    val tool = request.getValue("tools").jsonArray.single().jsonObject
    val toolChoice = request.getValue("tool_choice").jsonObject
    // Responses API expects name/description/strict at the tool's top level,
    // not nested under a "function" object like the chat completions API.
    assertThat(tool.getValue("type").jsonPrimitive.content).isEqualTo("function")
    assertThat(tool.getValue("name").jsonPrimitive.content).isEqualTo("Diagnostics")
    assertThat(tool.getValue("description").jsonPrimitive.content).isEqualTo("Read IDE diagnostics")
    assertThat(tool.getValue("strict").jsonPrimitive.boolean).isTrue()
    assertThat(tool).doesNotContainKey("function")
    assertThat(toolChoice.getValue("type").jsonPrimitive.content).isEqualTo("function")
    assertThat(toolChoice.getValue("name").jsonPrimitive.content).isEqualTo("Diagnostics")
}

@Test
fun `responses api request should encode mode tool choice as lowercase string`() {
    val payload = serializeRequest(
        messages = listOf(OpenAIMessage.User(content = Content.Text("hello"))),
        tools = emptyList(),
        toolChoice = openAIToolChoiceMode("required")
    )
    val request = json.parseToJsonElement(payload).jsonObject
    // Mode choices ("auto"/"none"/"required") serialize as a bare string.
    assertThat(request.getValue("tool_choice").jsonPrimitive.content).isEqualTo("required")
}
/**
 * Responses API serialization: a chat history containing a full tool
 * round-trip (assistant tool call + tool result) must be flattened into
 * typed "input" items — messages first, then a function_call item and its
 * matching function_call_output.
 */
@Test
fun `responses api request should encode agent tool history as input items`() {
// A /v1/responses URL selects the Responses API request shape.
val state = CustomServiceChatCompletionSettingsState().apply {
url = "https://example.com/v1/responses"
body.clear()
body["stream"] = true
body["input"] = "\$OPENAI_MESSAGES"
}
val client = CustomOpenAILLMClient.fromSettingsState("test-key", state)
// serializeProviderChatRequest is private; invoke it via reflection.
val serializeMethod = client.javaClass.getDeclaredMethod(
"serializeProviderChatRequest",
List::class.java,
LLModel::class.java,
List::class.java,
OpenAIToolChoice::class.java,
LLMParams::class.java,
Boolean::class.javaPrimitiveType
)
serializeMethod.isAccessible = true
// History: system prompt, user turn, assistant tool call whose arguments
// are a string-encoded JSON object, and the matching tool output.
val payload = serializeMethod.invoke(
client,
listOf(
OpenAIMessage.System(content = Content.Text("System instructions")),
OpenAIMessage.User(content = Content.Text("Read the fixture")),
OpenAIMessage.Assistant(
content = Content.Text("Calling Read"),
toolCalls = listOf(
OpenAIToolCall(
"call_read",
OpenAIFunction("Read", "\"{\\\"file_path\\\":\\\"/tmp/fixture.txt\\\"}\"")
)
)
),
OpenAIMessage.Tool(
content = Content.Text("1\tfixture contents"),
toolCallId = "call_read"
)
),
LLModel(
id = "gpt-test",
provider = CustomOpenAILLMClient.CustomOpenAI,
capabilities = emptyList(),
contextLength = 128_000,
maxOutputTokens = 4_096
),
emptyList<OpenAITool>(),
null,
CustomOpenAIParams(),
true
) as String
val request = json.parseToJsonElement(payload).jsonObject
val input = request.getValue("input").jsonArray
// One input item per history entry, in order: three messages, then the
// function call and its output.
assertThat(input.map { it.jsonObject.getValue("type").jsonPrimitive.content }).containsExactly(
"message",
"message",
"message",
"function_call",
"function_call_output"
)
// System messages map to the Responses API "developer" role.
assertThat(input[0].jsonObject.getValue("role").jsonPrimitive.content).isEqualTo("developer")
assertThat(input[1].jsonObject.getValue("role").jsonPrimitive.content).isEqualTo("user")
assertThat(input[2].jsonObject.getValue("role").jsonPrimitive.content).isEqualTo("assistant")
assertThat(
input[2].jsonObject.getValue("content").jsonArray.single().jsonObject.getValue("text").jsonPrimitive.content
).isEqualTo("Calling Read")
assertThat(input[3].jsonObject.getValue("name").jsonPrimitive.content).isEqualTo("Read")
assertThat(input[3].jsonObject.getValue("call_id").jsonPrimitive.content).isEqualTo("call_read")
// The string-encoded arguments must be unwrapped into plain JSON.
assertThat(input[3].jsonObject.getValue("arguments").jsonPrimitive.content)
.isEqualTo("""{"file_path":"/tmp/fixture.txt"}""")
assertThat(input[4].jsonObject.getValue("call_id").jsonPrimitive.content).isEqualTo("call_read")
assertThat(input[4].jsonObject.getValue("output").jsonPrimitive.content)
.isEqualTo("1\tfixture contents")
}
// Builds an OpenAIToolChoice.Mode via reflection: Mode is a value class, so
// its boxed form can only be produced through the synthetic "box-impl" method.
private fun openAIToolChoiceMode(value: String): OpenAIToolChoice =
    Class.forName("ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolChoice\$Mode")
        .getDeclaredMethod("box-impl", String::class.java)
        .invoke(null, value) as OpenAIToolChoice
}

View file

@ -0,0 +1,94 @@
package ee.carlrobert.codegpt.agent.clients
import ai.koog.prompt.message.Message
import ai.koog.prompt.streaming.StreamFrame
import ai.koog.prompt.streaming.toMessageResponses
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceChatCompletionSettingsState
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.asFlow
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.runBlocking
import org.assertj.core.api.Assertions.assertThat
import org.junit.Test
/**
 * Tests for SSE payload normalization and streaming decode in the custom
 * OpenAI client: unwrapping `data:` frames, skipping the `[DONE]` sentinel,
 * and reassembling streamed Responses API tool-call fragments into a single
 * named tool call.
 */
class StreamingPayloadNormalizerTest {
@Test
fun `should unwrap sse data payload`() {
// A plain "data: {...}" frame yields just the JSON payload.
assertThat(normalizeSsePayload("data: {\"id\":\"chunk-1\"}"))
.isEqualTo("{\"id\":\"chunk-1\"}")
}
@Test
fun `should ignore done sentinel`() {
// The SSE terminator carries no payload and must be dropped (null).
assertThat(normalizeSsePayload("data: [DONE]")).isNull()
}
@Test
fun `should unwrap multiline sse event frames`() {
// "event:" line followed by a "data:" line — only the data payload survives.
assertThat(
normalizeSsePayload("event: response.created\ndata: {\"type\":\"response.created\"}\n")
).isEqualTo("{\"type\":\"response.created\"}")
}
@Test
fun `should unwrap compact sse frames where event and data share a line`() {
// Some servers emit "event: ... data: ..." on one line; still unwrapped.
assertThat(
normalizeSsePayload("event: response.created data: {\"type\":\"response.created\"}")
).isEqualTo("{\"type\":\"response.created\"}")
}
@Test
fun `custom openai client should decode sse wrapped chat completion chunks`() {
val state = CustomServiceChatCompletionSettingsState().apply {
url = "https://example.com/v1/chat/completions"
body["stream"] = true
}
val client = CustomOpenAILLMClient.fromSettingsState("test-key", state)
// decodeStreamingResponse is private; reach it via reflection.
val decodeMethod = client.javaClass.getDeclaredMethod("decodeStreamingResponse", String::class.java)
decodeMethod.isAccessible = true
val response = decodeMethod.invoke(
client,
"data: {\"choices\":[],\"created\":0,\"id\":\"chunk-1\",\"model\":\"test-model\"}"
) as CustomOpenAIChatCompletionStreamResponse
assertThat(response.id).isEqualTo("chunk-1")
assertThat(response.model).isEqualTo("test-model")
assertThat(response.choices).isEmpty()
}
@Test
fun `responses api streaming tool call should collapse into one named tool call`() {
// /v1/responses endpoint selects the Responses API streaming decoder.
val state = CustomServiceChatCompletionSettingsState().apply {
url = "https://example.com/v1/responses"
body.clear()
body["stream"] = true
body["input"] = "\$OPENAI_MESSAGES"
}
val client = CustomOpenAILLMClient.fromSettingsState("test-key", state)
val decodeMethod = client.javaClass.getDeclaredMethod("decodeStreamingResponse", String::class.java)
decodeMethod.isAccessible = true
val processMethod = client.javaClass.getDeclaredMethod("processStreamingResponse", Flow::class.java)
processMethod.isAccessible = true
// Three streamed events: the tool-call item is added, its arguments
// arrive as a delta, then the response completes with the full call.
val chunks = listOf(
"""{"type":"response.output_item.added","item":{"type":"function_call","id":"fc_1","call_id":"call_diag","name":"Diagnostics","arguments":""},"output_index":0,"sequence_number":1}""",
"""{"type":"response.function_call_arguments.delta","item_id":"fc_1","output_index":0,"delta":"{\"file_path\":\"/tmp/mainwindow.cpp\",\"filter\":\"all\"}","call_id":"call_diag","sequence_number":2}""",
"""{"type":"response.completed","response":{"id":"resp-tool","object":"response","created_at":1,"model":"test-model","output":[{"type":"function_call","id":"fc_1","call_id":"call_diag","name":"Diagnostics","arguments":"{\"file_path\":\"/tmp/mainwindow.cpp\",\"filter\":\"all\"}","status":"completed"}],"parallel_tool_calls":true,"status":"completed","text":{}},"sequence_number":3}"""
).map { payload ->
decodeMethod.invoke(client, payload) as CustomOpenAIChatCompletionStreamResponse
}
val responses = runBlocking {
@Suppress("UNCHECKED_CAST")
val frames = processMethod.invoke(client, chunks.asFlow()) as Flow<StreamFrame>
frames.toList().toMessageResponses()
}
// The fragments must collapse into exactly one fully-formed tool call.
val toolCall = responses.filterIsInstance<Message.Tool.Call>().single()
assertThat(toolCall.id).isEqualTo("call_diag")
assertThat(toolCall.tool).isEqualTo("Diagnostics")
assertThat(toolCall.content).isEqualTo("""{"file_path":"/tmp/mainwindow.cpp","filter":"all"}""")
}
}

View file

@ -3,6 +3,7 @@ package ee.carlrobert.codegpt.agent.strategy
import ai.koog.prompt.llm.LLMProvider
import ai.koog.prompt.message.Message
import ai.koog.prompt.message.ResponseMetaInfo
import ee.carlrobert.codegpt.agent.normalizeToolArgumentsJson
import org.assertj.core.api.Assertions.assertThat
import kotlin.test.Test
@ -94,4 +95,11 @@ class SingleRunStrategyProviderTest {
assertThat(toolCall.tool).isEqualTo("Read")
assertThat(toolCall.content).isEqualTo("""{"path":"build.gradle.kts"}""")
}
@Test
fun `normalize tool arguments unwraps string encoded json object`() {
    // A JSON object double-encoded as a JSON string must come back as the plain object.
    val doubleEncoded = "\"{\\\"file_path\\\":\\\"/tmp/fixture.txt\\\"}\""
    assertThat(normalizeToolArgumentsJson(doubleEncoded))
        .isEqualTo("""{"file_path":"/tmp/fixture.txt"}""")
}
}

View file

@ -6,6 +6,8 @@ import com.intellij.util.messages.MessageBusConnection
import ee.carlrobert.codegpt.settings.service.FeatureType
import ee.carlrobert.codegpt.settings.service.ModelChangeNotifier
import ee.carlrobert.codegpt.settings.service.ServiceType
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettingsState
import ee.carlrobert.codegpt.settings.service.custom.CustomServicesSettings
import org.assertj.core.api.Assertions.assertThat
import testsupport.IntegrationTest
import java.util.concurrent.atomic.AtomicReference
@ -202,6 +204,30 @@ class ModelSettingsTest : IntegrationTest() {
assertThat(result).isEqualTo(ServiceType.OPENAI)
}
fun `test custom openai models keep configured order for chat and agent`() {
    val customServicesSettings = service<CustomServicesSettings>()
    val first = createCustomService(name = "First", model = "chat-first")
    val second = createCustomService(name = "Second", model = "chat-second")

    // Replaces the configured services with the given ordering.
    fun configure(vararg services: CustomServiceSettingsState) {
        customServicesSettings.state.services.clear()
        customServicesSettings.state.services.addAll(services)
    }

    configure(first, second)
    for (feature in listOf(FeatureType.CHAT, FeatureType.AGENT)) {
        assertThat(customModelNames(feature))
            .containsExactly("First (chat-first)", "Second (chat-second)")
    }

    // Reversing the configured order must reverse the presented order too.
    configure(second, first)
    for (feature in listOf(FeatureType.CHAT, FeatureType.AGENT)) {
        assertThat(customModelNames(feature))
            .containsExactly("Second (chat-second)", "First (chat-first)")
    }
}
fun `test migrateMissingProviderInformation updates missing providers`() {
val state = ModelSettingsState()
val detailsState = ModelDetailsState()
@ -227,4 +253,18 @@ class ModelSettingsTest : IntegrationTest() {
assertThat(modelSettings.getStoredModelForFeature(FeatureType.CHAT)).isEqualTo("unknown-model")
assertThat(modelSettings.getStoredProviderForFeature(FeatureType.CHAT)).isNull()
}
// Display names of the Custom OpenAI models offered for the given feature.
private fun customModelNames(featureType: FeatureType): List<String> =
    modelSettings.getAvailableModels(featureType)
        .filter { it.provider == ServiceType.CUSTOM_OPENAI }
        .map { it.displayName }
// Builds a custom service whose chat-completion body contains only the given model id.
private fun createCustomService(name: String, model: String): CustomServiceSettingsState =
    CustomServiceSettingsState().also { state ->
        state.name = name
        state.chatCompletionSettings.body.clear()
        state.chatCompletionSettings.body["model"] = model
    }
}

View file

@ -0,0 +1,80 @@
package ee.carlrobert.codegpt.ui.textarea
import com.intellij.openapi.components.service
import ee.carlrobert.codegpt.diagnostics.DiagnosticsFilter
import ee.carlrobert.codegpt.diagnostics.ProjectDiagnosticsService
import ee.carlrobert.codegpt.settings.service.FeatureType
import ee.carlrobert.codegpt.ui.textarea.header.tag.DiagnosticsTagDetails
import ee.carlrobert.codegpt.ui.textarea.header.tag.EditorTagDetails
import ee.carlrobert.codegpt.ui.textarea.header.tag.FileTagDetails
import ee.carlrobert.codegpt.ui.textarea.header.tag.TagManager
import ee.carlrobert.codegpt.ui.textarea.lookup.action.selectedContextFiles
import ee.carlrobert.codegpt.ui.textarea.lookup.group.DiagnosticsGroupItem
import kotlinx.coroutines.runBlocking
import org.assertj.core.api.Assertions.assertThat
import testsupport.IntegrationTest
/**
 * Integration tests for the diagnostics feature: collecting IDE highlight
 * results through [ProjectDiagnosticsService], resolving the selected file
 * context from tags, and exposing the diagnostics lookup group in agent mode.
 */
class DiagnosticsIntegrationTest : IntegrationTest() {
fun `test diagnostics service should return errors for broken file`() {
// "int value = ;" is invalid Java, so highlighting must produce an ERROR.
val brokenFile = myFixture.configureByText(
"Broken.java",
"class Broken { void test() { int value = ; } }"
).virtualFile
myFixture.doHighlighting()
val report = project.service<ProjectDiagnosticsService>()
.collect(brokenFile, DiagnosticsFilter.ERRORS_ONLY)
assertThat(report.hasDiagnostics).isTrue()
assertThat(report.content).contains("[ERROR]")
assertThat(report.content).contains("Broken.java")
}
fun `test diagnostics service should return empty for clean file`() {
// A syntactically valid file yields an empty, error-free report.
val cleanFile = myFixture.configureByText(
"Clean.java",
"class Clean { void test() { int value = 1; } }"
).virtualFile
myFixture.doHighlighting()
val report = project.service<ProjectDiagnosticsService>()
.collect(cleanFile, DiagnosticsFilter.ERRORS_ONLY)
assertThat(report.hasDiagnostics).isFalse()
assertThat(report.content).isBlank()
assertThat(report.error).isNull()
}
fun `test selectedContextFiles should use selected file context only`() {
val firstFile = myFixture.configureByText("First.java", "class First {}").virtualFile
val secondFile = myFixture.configureByText("Second.java", "class Second {}").virtualFile
// Deselected editor tags must be excluded from the resolved context.
val unselectedEditorTag = EditorTagDetails(secondFile).apply { selected = false }
val contextFiles = selectedContextFiles(
listOf(
FileTagDetails(firstFile),
DiagnosticsTagDetails(firstFile, DiagnosticsFilter.ALL),
EditorTagDetails(secondFile),
unselectedEditorTag
)
)
// Each file appears once; the diagnostics tag does not duplicate firstFile.
assertThat(contextFiles.map { it.path })
.containsExactly(firstFile.path, secondFile.path)
}
fun `test diagnostics group should be available in agent mode with file context`() {
val file = myFixture.configureByText("AgentFile.java", "class AgentFile {}").virtualFile
val tagManager = TagManager().apply {
addTag(FileTagDetails(file))
}
// With a file tag present, agent mode must offer the diagnostics group.
val groups = SearchManager(project, tagManager, FeatureType.AGENT).getDefaultGroups()
val diagnosticsGroup = groups.filterIsInstance<DiagnosticsGroupItem>().single()
val actions = runBlocking { diagnosticsGroup.getLookupItems("") }
assertThat(diagnosticsGroup.enabled).isTrue()
assertThat(actions.map { it.displayName }).containsExactly("Errors only", "All")
}
}