mirror of
https://github.com/carlrobertoh/ProxyAI.git
synced 2026-05-16 19:44:36 +00:00
Merge branch 'master' of github.com:AlexanderLuck/ProxyAI into AlexanderLuck/master
This commit is contained in:
commit
8fdeba74da
29 changed files with 1712 additions and 275 deletions
|
|
@ -534,6 +534,13 @@ object AgentFactory {
|
|||
hookManager = hookManager,
|
||||
)
|
||||
)
|
||||
if (SubagentTool.DIAGNOSTICS in selected) tool(
|
||||
DiagnosticsTool(
|
||||
project = project,
|
||||
sessionId = sessionId,
|
||||
hookManager = hookManager,
|
||||
)
|
||||
)
|
||||
if (SubagentTool.WEB_SEARCH in selected) tool(
|
||||
WebSearchTool(
|
||||
workingDirectory = project.basePath ?: System.getProperty("user.dir"),
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ import com.intellij.openapi.components.service
|
|||
import com.intellij.openapi.project.Project
|
||||
import ee.carlrobert.codegpt.EncodingManager
|
||||
import ee.carlrobert.codegpt.agent.clients.shouldStream
|
||||
import ee.carlrobert.codegpt.agent.clients.shouldStreamCustomOpenAI
|
||||
import ee.carlrobert.codegpt.agent.strategy.CODE_AGENT_COMPRESSION
|
||||
import ee.carlrobert.codegpt.agent.strategy.HistoryCompressionConfig
|
||||
import ee.carlrobert.codegpt.agent.strategy.SingleRunStrategyProvider
|
||||
|
|
@ -34,7 +35,6 @@ import ee.carlrobert.codegpt.settings.hooks.HookManager
|
|||
import ee.carlrobert.codegpt.settings.models.ModelSettings
|
||||
import ee.carlrobert.codegpt.settings.service.FeatureType
|
||||
import ee.carlrobert.codegpt.settings.service.ServiceType
|
||||
import ee.carlrobert.codegpt.settings.service.custom.CustomServicesSettings
|
||||
import ee.carlrobert.codegpt.settings.skills.SkillDiscoveryService
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.ui.approval.BashPayload
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.ui.approval.ToolApprovalRequest
|
||||
|
|
@ -89,7 +89,7 @@ object ProxyAIAgent {
|
|||
val modelSelection =
|
||||
service<ModelSettings>().getModelSelectionForFeature(FeatureType.AGENT)
|
||||
val skills = project.service<SkillDiscoveryService>().listSkills()
|
||||
val stream = shouldStreamAgentToolLoop(project, provider)
|
||||
val stream = shouldStreamAgentToolLoop(provider)
|
||||
val projectInstructions = loadProjectInstructions(project.basePath)
|
||||
val executor = AgentFactory.createExecutor(provider, events)
|
||||
val pendingMessageQueue = pendingMessages.getOrPut(sessionId) { ArrayDeque() }
|
||||
|
|
@ -268,18 +268,10 @@ object ProxyAIAgent {
|
|||
}
|
||||
|
||||
private fun shouldStreamAgentToolLoop(
|
||||
project: Project,
|
||||
provider: ServiceType,
|
||||
): Boolean {
|
||||
return when (provider) {
|
||||
ServiceType.CUSTOM_OPENAI -> {
|
||||
val selectedServiceId =
|
||||
project.service<ModelSettings>().getStoredModelForFeature(FeatureType.AGENT)
|
||||
project.service<CustomServicesSettings>().state.services
|
||||
.firstOrNull { it.id == selectedServiceId }?.chatCompletionSettings?.shouldStream()
|
||||
?: false
|
||||
}
|
||||
|
||||
ServiceType.CUSTOM_OPENAI -> shouldStreamCustomOpenAI(FeatureType.AGENT)
|
||||
ServiceType.GOOGLE -> false
|
||||
else -> true
|
||||
}
|
||||
|
|
@ -323,6 +315,13 @@ object ProxyAIAgent {
|
|||
hookManager = hookManager,
|
||||
)
|
||||
)
|
||||
tool(
|
||||
DiagnosticsTool(
|
||||
project = project,
|
||||
sessionId = sessionId,
|
||||
hookManager = hookManager,
|
||||
)
|
||||
)
|
||||
tool(
|
||||
WebSearchTool(
|
||||
workingDirectory = workingDirectory,
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ enum class SubagentTool(val id: String, val displayName: String, val isWrite: Bo
|
|||
READ("read", "Read", false),
|
||||
TODO_WRITE("todowrite", "TodoWrite", false),
|
||||
INTELLIJ_SEARCH("intellijsearch", "IntelliJSearch", false),
|
||||
DIAGNOSTICS("diagnostics", "Diagnostics", false),
|
||||
WEB_SEARCH("websearch", "WebSearch", false),
|
||||
WEB_FETCH("webfetch", "WebFetch", false),
|
||||
MCP("MCP", "MCP", false),
|
||||
|
|
|
|||
|
|
@ -0,0 +1,47 @@
|
|||
package ee.carlrobert.codegpt.agent
|
||||
|
||||
import kotlinx.serialization.json.Json
|
||||
import kotlinx.serialization.json.JsonElement
|
||||
import kotlinx.serialization.json.JsonObject
|
||||
import kotlinx.serialization.json.JsonPrimitive
|
||||
import kotlinx.serialization.json.contentOrNull
|
||||
|
||||
private val toolArgumentsJson = Json {
|
||||
ignoreUnknownKeys = true
|
||||
}
|
||||
|
||||
internal fun normalizeToolArgumentsJson(rawArgs: String?): String? {
|
||||
val payload = rawArgs?.trim().orEmpty()
|
||||
if (payload.isBlank()) {
|
||||
return null
|
||||
}
|
||||
|
||||
val element = runCatching {
|
||||
toolArgumentsJson.parseToJsonElement(payload)
|
||||
}.getOrNull() ?: return null
|
||||
|
||||
return normalizeToolArgumentsElement(element)?.toString()
|
||||
}
|
||||
|
||||
private tailrec fun normalizeToolArgumentsElement(element: JsonElement): JsonObject? {
|
||||
return when (element) {
|
||||
is JsonObject -> element
|
||||
is JsonPrimitive -> {
|
||||
if (!element.isString) {
|
||||
null
|
||||
} else {
|
||||
val nestedPayload = element.contentOrNull?.trim().orEmpty()
|
||||
if (nestedPayload.isBlank()) {
|
||||
null
|
||||
} else {
|
||||
val nested = runCatching {
|
||||
toolArgumentsJson.parseToJsonElement(nestedPayload)
|
||||
}.getOrNull() ?: return null
|
||||
normalizeToolArgumentsElement(nested)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
else -> null
|
||||
}
|
||||
}
|
||||
|
|
@ -14,6 +14,7 @@ enum class ToolName(val id: String, val aliases: Set<String> = emptySet()) {
|
|||
BASH_OUTPUT("BashOutput"),
|
||||
KILL_SHELL("KillShell"),
|
||||
INTELLIJ_SEARCH("IntelliJSearch"),
|
||||
DIAGNOSTICS("Diagnostics"),
|
||||
WEB_SEARCH("WebSearch"),
|
||||
WEB_FETCH("WebFetch"),
|
||||
MCP("MCP"),
|
||||
|
|
@ -94,6 +95,13 @@ object ToolSpecs {
|
|||
IntelliJSearchTool.Result.serializer()
|
||||
)
|
||||
)
|
||||
register(
|
||||
ToolSpec(
|
||||
ToolName.DIAGNOSTICS,
|
||||
DiagnosticsTool.Args.serializer(),
|
||||
DiagnosticsTool.Result.serializer()
|
||||
)
|
||||
)
|
||||
register(
|
||||
ToolSpec(
|
||||
ToolName.WEB_SEARCH,
|
||||
|
|
@ -193,8 +201,12 @@ object ToolSpecs {
|
|||
if (serializer == null || payload.isBlank()) {
|
||||
return null
|
||||
}
|
||||
val typedSerializer = serializer as KSerializer<Any>
|
||||
return runCatching {
|
||||
json.decodeFromString(serializer as KSerializer<Any>, payload)
|
||||
json.decodeFromString(typedSerializer, payload)
|
||||
}.recoverCatching {
|
||||
val normalized = normalizeToolArgumentsJson(payload) ?: throw it
|
||||
json.decodeFromString(typedSerializer, normalized)
|
||||
}.getOrNull()
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ import com.fasterxml.jackson.core.type.TypeReference
|
|||
import com.fasterxml.jackson.databind.ObjectMapper
|
||||
import ee.carlrobert.codegpt.codecompletions.InfillPromptTemplate
|
||||
import ee.carlrobert.codegpt.codecompletions.InfillRequest
|
||||
import ee.carlrobert.codegpt.agent.normalizeToolArgumentsJson
|
||||
import ee.carlrobert.codegpt.settings.Placeholder
|
||||
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceCodeCompletionSettingsState
|
||||
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceChatCompletionSettingsState
|
||||
|
|
@ -40,6 +41,7 @@ import kotlinx.serialization.builtins.ListSerializer
|
|||
import kotlinx.serialization.descriptors.elementNames
|
||||
import kotlinx.serialization.json.*
|
||||
import java.net.URI
|
||||
import java.util.UUID
|
||||
import kotlin.io.encoding.ExperimentalEncodingApi
|
||||
import kotlin.time.Clock
|
||||
|
||||
|
|
@ -193,7 +195,7 @@ class CustomOpenAILLMClient(
|
|||
}
|
||||
|
||||
if (isResponsesApi) {
|
||||
return serializeResponsesApiRequest(state, messages, model, tools)
|
||||
return serializeResponsesApiRequest(state, messages, model, tools, toolChoice)
|
||||
}
|
||||
|
||||
val customParams: CustomOpenAIParams = params.toCustomOpenAIParams(state)
|
||||
|
|
@ -245,9 +247,12 @@ class CustomOpenAILLMClient(
|
|||
state: CustomServiceChatCompletionSettingsState,
|
||||
messages: List<OpenAIMessage>,
|
||||
model: LLModel,
|
||||
tools: List<OpenAITool>?
|
||||
tools: List<OpenAITool>?,
|
||||
toolChoice: OpenAIToolChoice?
|
||||
): String {
|
||||
val streamRequest = state.shouldStream()
|
||||
val inputJson = messages.toResponsesApiItemsJson()
|
||||
val prompt = renderCustomOpenAIPrompt(messages, json)
|
||||
|
||||
return buildJsonObject {
|
||||
state.body.forEach { (key, value) ->
|
||||
|
|
@ -256,7 +261,8 @@ class CustomOpenAILLMClient(
|
|||
key = key,
|
||||
value = value,
|
||||
streamRequest = streamRequest,
|
||||
messages = messages,
|
||||
messagesJson = inputJson,
|
||||
prompt = prompt,
|
||||
credential = apiKey,
|
||||
json = json
|
||||
)
|
||||
|
|
@ -264,8 +270,9 @@ class CustomOpenAILLMClient(
|
|||
}
|
||||
put("model", JsonPrimitive(model.id))
|
||||
if (!tools.isNullOrEmpty()) {
|
||||
put("tools", json.encodeToJsonElement(ListSerializer(OpenAITool.serializer()), tools))
|
||||
put("tools", JsonArray(tools.map { it.toResponsesApiToolJson() }))
|
||||
}
|
||||
toolChoice?.toResponsesApiToolChoiceJson()?.let { put("tool_choice", it) }
|
||||
}.toString()
|
||||
}
|
||||
|
||||
|
|
@ -418,10 +425,17 @@ class CustomOpenAILLMClient(
|
|||
}
|
||||
|
||||
override fun decodeStreamingResponse(data: String): CustomOpenAIChatCompletionStreamResponse {
|
||||
val payload = normalizeSsePayload(data)
|
||||
?: return CustomOpenAIChatCompletionStreamResponse(
|
||||
choices = emptyList(),
|
||||
created = 0,
|
||||
id = "",
|
||||
model = ""
|
||||
)
|
||||
if (!isResponsesApi) {
|
||||
return json.decodeFromString(data)
|
||||
return json.decodeFromString(payload)
|
||||
}
|
||||
return adaptResponsesApiStreamEvent(data)
|
||||
return adaptResponsesApiStreamEvent(payload)
|
||||
}
|
||||
|
||||
override fun decodeResponse(data: String): CustomOpenAIChatCompletionResponse {
|
||||
|
|
@ -448,48 +462,42 @@ class CustomOpenAILLMClient(
|
|||
)
|
||||
}
|
||||
|
||||
"response.function_call_arguments.delta" -> {
|
||||
val argsDelta = event["delta"]?.jsonPrimitive?.contentOrNull ?: ""
|
||||
val callId = event["call_id"]?.jsonPrimitive?.contentOrNull ?: ""
|
||||
val index = event["output_index"]?.jsonPrimitive?.intOrNull ?: 0
|
||||
CustomOpenAIStreamChoice(
|
||||
delta = CustomOpenAIStreamDelta(
|
||||
toolCalls = listOf(
|
||||
CustomOpenAIToolCall(
|
||||
id = callId,
|
||||
index = index,
|
||||
function = CustomOpenAIFunction(arguments = argsDelta)
|
||||
"response.function_call_arguments.delta",
|
||||
"response.output_item.added" -> null
|
||||
|
||||
"response.completed" -> {
|
||||
val responseObject = event["response"]?.jsonObject
|
||||
val toolCalls = responseObject?.get("output")
|
||||
?.jsonArray
|
||||
?.mapIndexedNotNull { index, item ->
|
||||
val itemObject = item.jsonObject
|
||||
if (itemObject["type"]?.jsonPrimitive?.contentOrNull != "function_call") {
|
||||
return@mapIndexedNotNull null
|
||||
}
|
||||
|
||||
val rawArguments = when (val arguments = itemObject["arguments"]) {
|
||||
is JsonPrimitive -> arguments.contentOrNull
|
||||
is JsonObject -> arguments.toString()
|
||||
else -> null
|
||||
}
|
||||
CustomOpenAIToolCall(
|
||||
id = itemObject["call_id"]?.jsonPrimitive?.contentOrNull,
|
||||
index = index,
|
||||
function = CustomOpenAIFunction(
|
||||
name = itemObject["name"]?.jsonPrimitive?.contentOrNull,
|
||||
arguments = normalizeToolArgumentsJson(rawArguments) ?: rawArguments.orEmpty()
|
||||
)
|
||||
)
|
||||
}
|
||||
.orEmpty()
|
||||
CustomOpenAIStreamChoice(
|
||||
finishReason = "stop",
|
||||
delta = CustomOpenAIStreamDelta(
|
||||
toolCalls = toolCalls.takeIf { it.isNotEmpty() }
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
"response.output_item.added" -> {
|
||||
val item = event["item"]?.jsonObject
|
||||
if (item?.get("type")?.jsonPrimitive?.contentOrNull == "function_call") {
|
||||
val callId = item["call_id"]?.jsonPrimitive?.contentOrNull ?: ""
|
||||
val name = item["name"]?.jsonPrimitive?.contentOrNull ?: ""
|
||||
val index = event["output_index"]?.jsonPrimitive?.intOrNull ?: 0
|
||||
CustomOpenAIStreamChoice(
|
||||
delta = CustomOpenAIStreamDelta(
|
||||
toolCalls = listOf(
|
||||
CustomOpenAIToolCall(
|
||||
id = callId,
|
||||
index = index,
|
||||
function = CustomOpenAIFunction(name = name)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
} else null
|
||||
}
|
||||
|
||||
"response.completed" -> CustomOpenAIStreamChoice(
|
||||
finishReason = "stop",
|
||||
delta = CustomOpenAIStreamDelta()
|
||||
)
|
||||
|
||||
else -> null
|
||||
}
|
||||
|
||||
|
|
@ -526,12 +534,20 @@ class CustomOpenAILLMClient(
|
|||
}
|
||||
|
||||
"function_call" -> {
|
||||
val rawArguments = when (val arguments = itemObj["arguments"]) {
|
||||
is JsonPrimitive -> arguments.contentOrNull
|
||||
is JsonObject -> arguments.toString()
|
||||
else -> null
|
||||
}
|
||||
toolCallsJson.add(buildJsonObject {
|
||||
put("id", itemObj["call_id"] ?: JsonPrimitive(""))
|
||||
put("type", JsonPrimitive("function"))
|
||||
putJsonObject("function") {
|
||||
put("name", itemObj["name"] ?: JsonPrimitive(""))
|
||||
put("arguments", itemObj["arguments"] ?: JsonPrimitive("{}"))
|
||||
put(
|
||||
"arguments",
|
||||
JsonPrimitive(normalizeToolArgumentsJson(rawArguments) ?: rawArguments ?: "{}")
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
@ -631,12 +647,12 @@ class CustomOpenAILLMClient(
|
|||
}
|
||||
|
||||
assistantToolCalls.forEach { toolCall ->
|
||||
val arguments = normalizeToolArgumentsJson(toolCall.function.arguments) ?: "{}"
|
||||
add(
|
||||
Message.Tool.Call(
|
||||
id = toolCall.id,
|
||||
tool = toolCall.function.name,
|
||||
content = toolCall.function.arguments.takeIf { it.isNotEmpty() }
|
||||
?: "{}",
|
||||
content = arguments,
|
||||
metaInfo = metaInfo
|
||||
)
|
||||
)
|
||||
|
|
@ -706,22 +722,38 @@ internal fun transformCustomOpenAIBodyValue(
|
|||
messages: List<OpenAIMessage>,
|
||||
credential: String,
|
||||
json: Json
|
||||
): JsonElement {
|
||||
return transformCustomOpenAIBodyValue(
|
||||
key = key,
|
||||
value = value,
|
||||
streamRequest = streamRequest,
|
||||
messagesJson = json.encodeToJsonElement(
|
||||
ListSerializer(OpenAIMessage.serializer()),
|
||||
messages
|
||||
),
|
||||
prompt = renderCustomOpenAIPrompt(messages, json),
|
||||
credential = credential,
|
||||
json = json
|
||||
)
|
||||
}
|
||||
|
||||
private fun transformCustomOpenAIBodyValue(
|
||||
key: String?,
|
||||
value: Any?,
|
||||
streamRequest: Boolean,
|
||||
messagesJson: JsonElement,
|
||||
prompt: String,
|
||||
credential: String,
|
||||
json: Json
|
||||
): JsonElement {
|
||||
return when (value) {
|
||||
null -> JsonNull
|
||||
is JsonElement -> value
|
||||
is String -> when {
|
||||
!streamRequest && key == "stream" -> JsonPrimitive(false)
|
||||
CustomServicePlaceholders.isMessages(value) -> json.parseToJsonElement(
|
||||
json.encodeToString(ListSerializer(OpenAIMessage.serializer()), messages)
|
||||
)
|
||||
CustomServicePlaceholders.isMessages(value) -> messagesJson
|
||||
|
||||
CustomServicePlaceholders.isPrompt(value) -> JsonPrimitive(
|
||||
renderCustomOpenAIPrompt(
|
||||
messages,
|
||||
json
|
||||
)
|
||||
)
|
||||
CustomServicePlaceholders.isPrompt(value) -> JsonPrimitive(prompt)
|
||||
|
||||
value.contains($$"$CUSTOM_SERVICE_API_KEY") -> {
|
||||
JsonPrimitive(value.replace($$"$CUSTOM_SERVICE_API_KEY", credential))
|
||||
|
|
@ -740,7 +772,8 @@ internal fun transformCustomOpenAIBodyValue(
|
|||
key = nestedKey.toString(),
|
||||
value = nestedValue,
|
||||
streamRequest = streamRequest,
|
||||
messages = messages,
|
||||
messagesJson = messagesJson,
|
||||
prompt = prompt,
|
||||
credential = credential,
|
||||
json = json
|
||||
)
|
||||
|
|
@ -753,7 +786,8 @@ internal fun transformCustomOpenAIBodyValue(
|
|||
key = null,
|
||||
value = item,
|
||||
streamRequest = streamRequest,
|
||||
messages = messages,
|
||||
messagesJson = messagesJson,
|
||||
prompt = prompt,
|
||||
credential = credential,
|
||||
json = json
|
||||
)
|
||||
|
|
@ -766,7 +800,8 @@ internal fun transformCustomOpenAIBodyValue(
|
|||
key = null,
|
||||
value = item,
|
||||
streamRequest = streamRequest,
|
||||
messages = messages,
|
||||
messagesJson = messagesJson,
|
||||
prompt = prompt,
|
||||
credential = credential,
|
||||
json = json
|
||||
)
|
||||
|
|
@ -784,6 +819,174 @@ internal fun renderCustomOpenAIPrompt(messages: List<OpenAIMessage>, json: Json)
|
|||
}
|
||||
}
|
||||
|
||||
private fun List<OpenAIMessage>.toResponsesApiItemsJson(): JsonArray {
|
||||
return JsonArray(
|
||||
buildList {
|
||||
for (message in this@toResponsesApiItemsJson) {
|
||||
addAll(message.toResponsesApiItemsJson())
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
private fun OpenAIMessage.toResponsesApiItemsJson(): List<JsonObject> {
|
||||
return when (this) {
|
||||
is OpenAIMessage.System, is OpenAIMessage.Developer -> listOf(
|
||||
buildResponsesApiInputMessageItemJson(
|
||||
role = "developer",
|
||||
content = content.toResponsesApiInputContentJson()
|
||||
)
|
||||
)
|
||||
|
||||
is OpenAIMessage.User -> listOf(
|
||||
buildResponsesApiInputMessageItemJson(
|
||||
role = "user",
|
||||
content = content.toResponsesApiInputContentJson()
|
||||
)
|
||||
)
|
||||
|
||||
is OpenAIMessage.Assistant -> buildList {
|
||||
reasoningContent
|
||||
?.takeIf { it.isNotBlank() }
|
||||
?.let { reasoning ->
|
||||
add(
|
||||
buildJsonObject {
|
||||
put("type", JsonPrimitive("reasoning"))
|
||||
put("id", JsonPrimitive(UUID.randomUUID().toString()))
|
||||
putJsonArray("summary") {
|
||||
addJsonObject {
|
||||
put("type", JsonPrimitive("summary_text"))
|
||||
put("text", JsonPrimitive(reasoning))
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
content?.text()
|
||||
?.takeIf { it.isNotBlank() }
|
||||
?.let { assistantText ->
|
||||
add(
|
||||
buildJsonObject {
|
||||
put("type", JsonPrimitive("message"))
|
||||
put("role", JsonPrimitive("assistant"))
|
||||
putJsonArray("content") {
|
||||
addJsonObject {
|
||||
put("type", JsonPrimitive("output_text"))
|
||||
put("text", JsonPrimitive(assistantText))
|
||||
put("annotations", JsonArray(emptyList()))
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
toolCalls.orEmpty().forEach { toolCall ->
|
||||
val arguments = normalizeToolArgumentsJson(toolCall.function.arguments) ?: "{}"
|
||||
add(
|
||||
buildJsonObject {
|
||||
put("type", JsonPrimitive("function_call"))
|
||||
put("arguments", JsonPrimitive(arguments))
|
||||
put("call_id", JsonPrimitive(toolCall.id))
|
||||
put("name", JsonPrimitive(toolCall.function.name))
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
is OpenAIMessage.Tool -> listOf(
|
||||
buildJsonObject {
|
||||
put("type", JsonPrimitive("function_call_output"))
|
||||
put("call_id", JsonPrimitive(toolCallId))
|
||||
put("output", JsonPrimitive(content?.text().orEmpty()))
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
private fun buildResponsesApiInputMessageItemJson(
|
||||
role: String,
|
||||
content: JsonArray
|
||||
): JsonObject {
|
||||
return buildJsonObject {
|
||||
put("type", JsonPrimitive("message"))
|
||||
put("role", JsonPrimitive(role))
|
||||
put("content", content)
|
||||
}
|
||||
}
|
||||
|
||||
private fun Content?.toResponsesApiInputContentJson(): JsonArray {
|
||||
if (this == null) {
|
||||
return JsonArray(emptyList())
|
||||
}
|
||||
|
||||
return when (this) {
|
||||
is Content.Parts -> JsonArray(value.mapNotNull { it.toResponsesApiInputContentJson() })
|
||||
else -> JsonArray(
|
||||
listOf(
|
||||
buildJsonObject {
|
||||
put("type", JsonPrimitive("input_text"))
|
||||
put("text", JsonPrimitive(text()))
|
||||
}
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
private fun OpenAIContentPart.toResponsesApiInputContentJson(): JsonObject? {
|
||||
return when (this) {
|
||||
is OpenAIContentPart.Text -> buildJsonObject {
|
||||
put("type", JsonPrimitive("input_text"))
|
||||
put("text", JsonPrimitive(text))
|
||||
}
|
||||
|
||||
is OpenAIContentPart.Image -> buildJsonObject {
|
||||
put("type", JsonPrimitive("input_image"))
|
||||
imageUrl.detail?.let { put("detail", JsonPrimitive(it)) }
|
||||
put("imageUrl", JsonPrimitive(imageUrl.url))
|
||||
}
|
||||
|
||||
is OpenAIContentPart.File -> buildJsonObject {
|
||||
put("type", JsonPrimitive("input_file"))
|
||||
file.fileData?.let { put("fileData", JsonPrimitive(it)) }
|
||||
file.fileId?.let { put("fileId", JsonPrimitive(it)) }
|
||||
file.filename?.let { put("filename", JsonPrimitive(it)) }
|
||||
}
|
||||
|
||||
else -> null
|
||||
}
|
||||
}
|
||||
|
||||
internal fun OpenAITool.toResponsesApiToolJson(): JsonObject {
|
||||
return buildJsonObject {
|
||||
put("type", JsonPrimitive("function"))
|
||||
put("name", JsonPrimitive(function.name))
|
||||
put(
|
||||
"parameters",
|
||||
function.parameters ?: buildJsonObject {
|
||||
put("type", JsonPrimitive("object"))
|
||||
putJsonObject("properties") {}
|
||||
putJsonArray("required") {}
|
||||
}
|
||||
)
|
||||
function.strict?.let { put("strict", JsonPrimitive(it)) }
|
||||
function.description?.takeIf { it.isNotBlank() }?.let {
|
||||
put("description", JsonPrimitive(it))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal fun OpenAIToolChoice.toResponsesApiToolChoiceJson(): JsonElement {
|
||||
return when (this) {
|
||||
is OpenAIToolChoice.Function -> buildJsonObject {
|
||||
put("type", JsonPrimitive("function"))
|
||||
put("name", JsonPrimitive(function.name))
|
||||
}
|
||||
|
||||
is OpenAIToolChoice.Mode -> JsonPrimitive(value)
|
||||
}
|
||||
}
|
||||
|
||||
internal object CustomOpenAIChatCompletionRequestSerializer :
|
||||
CustomOpenAIAdditionalPropertiesFlatteningSerializer(CustomOpenAIChatCompletionRequest.serializer())
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,14 @@
|
|||
package ee.carlrobert.codegpt.agent.clients
|
||||
|
||||
import com.intellij.openapi.components.service
|
||||
import ee.carlrobert.codegpt.settings.service.FeatureType
|
||||
import ee.carlrobert.codegpt.settings.service.custom.CustomServicesSettings
|
||||
|
||||
internal fun shouldStreamCustomOpenAI(featureType: FeatureType): Boolean {
|
||||
return runCatching {
|
||||
service<CustomServicesSettings>()
|
||||
.customServiceStateForFeatureType(featureType)
|
||||
.chatCompletionSettings
|
||||
.shouldStream()
|
||||
}.getOrDefault(false)
|
||||
}
|
||||
|
|
@ -180,8 +180,11 @@ public class ProxyAILLMClient(
|
|||
}
|
||||
}
|
||||
|
||||
override fun decodeStreamingResponse(data: String): ProxyAIChatCompletionStreamResponse =
|
||||
json.decodeFromString(data)
|
||||
override fun decodeStreamingResponse(data: String): ProxyAIChatCompletionStreamResponse {
|
||||
val payload = normalizeSsePayload(data)
|
||||
?: return ProxyAIChatCompletionStreamResponse()
|
||||
return json.decodeFromString(payload)
|
||||
}
|
||||
|
||||
override fun decodeResponse(data: String): ProxyAIChatCompletionResponse =
|
||||
json.decodeFromString(data)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,31 @@
|
|||
package ee.carlrobert.codegpt.agent.clients
|
||||
|
||||
internal fun normalizeSsePayload(rawData: String): String? {
|
||||
val normalized = rawData
|
||||
.replace("\r\n", "\n")
|
||||
.replace('\r', '\n')
|
||||
.trim()
|
||||
if (normalized.isEmpty()) {
|
||||
return null
|
||||
}
|
||||
|
||||
val dataLines = normalized.lineSequence()
|
||||
.mapNotNull { line ->
|
||||
val markerIndex = line.indexOf("data:")
|
||||
if (markerIndex >= 0) {
|
||||
line.substring(markerIndex + "data:".length).trimStart()
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
.filter { it.isNotEmpty() }
|
||||
.toList()
|
||||
|
||||
val payload = if (dataLines.isNotEmpty()) {
|
||||
dataLines.joinToString("\n")
|
||||
} else {
|
||||
normalized
|
||||
}
|
||||
|
||||
return payload.takeUnless { it.isBlank() || it == "[DONE]" }
|
||||
}
|
||||
|
|
@ -23,15 +23,13 @@ import com.intellij.openapi.vfs.LocalFileSystem
|
|||
import ee.carlrobert.codegpt.ReferencedFile
|
||||
import ee.carlrobert.codegpt.agent.AgentEvents
|
||||
import ee.carlrobert.codegpt.agent.MessageWithContext
|
||||
import ee.carlrobert.codegpt.agent.normalizeToolArgumentsJson
|
||||
import ee.carlrobert.codegpt.agent.credits.extractCreditsSnapshot
|
||||
import ee.carlrobert.codegpt.completions.CompletionRequestUtil
|
||||
import ee.carlrobert.codegpt.settings.service.ServiceType
|
||||
import ee.carlrobert.codegpt.toolwindow.agent.AgentCreditsEvent
|
||||
import ee.carlrobert.codegpt.ui.textarea.TagProcessorFactory
|
||||
import ee.carlrobert.codegpt.ui.textarea.header.tag.TagDetails
|
||||
import kotlinx.serialization.SerializationException
|
||||
import kotlinx.serialization.json.Json
|
||||
import kotlinx.serialization.json.jsonObject
|
||||
import java.util.*
|
||||
import ee.carlrobert.codegpt.conversations.message.Message as ChatMessage
|
||||
|
||||
|
|
@ -186,15 +184,11 @@ private suspend fun AIAgentLLMWriteSession.requestResponses(
|
|||
.toMessageResponses()
|
||||
.map {
|
||||
if (it is Message.Tool.Call) {
|
||||
try {
|
||||
// validate json
|
||||
Json.parseToJsonElement(it.content).jsonObject
|
||||
val normalizedArgs = normalizeToolArgumentsJson(it.content) ?: "{}"
|
||||
if (normalizedArgs == it.content) {
|
||||
it
|
||||
} catch (_: SerializationException) {
|
||||
// allows agent to retry the request
|
||||
it.copy(parts = listOf(it.parts[0].copy(text = "{}")))
|
||||
} catch (e: Exception) {
|
||||
throw e
|
||||
} else {
|
||||
it.copy(parts = listOf(it.parts[0].copy(text = normalizedArgs)))
|
||||
}
|
||||
} else {
|
||||
it
|
||||
|
|
|
|||
|
|
@ -0,0 +1,123 @@
|
|||
package ee.carlrobert.codegpt.agent.tools
|
||||
|
||||
import ai.koog.agents.core.tools.annotations.LLMDescription
|
||||
import com.intellij.openapi.components.service
|
||||
import com.intellij.openapi.project.Project
|
||||
import ee.carlrobert.codegpt.diagnostics.DiagnosticsFilter
|
||||
import ee.carlrobert.codegpt.diagnostics.ProjectDiagnosticsService
|
||||
import ee.carlrobert.codegpt.settings.ProxyAISettingsService
|
||||
import ee.carlrobert.codegpt.settings.ToolPermissionPolicy
|
||||
import ee.carlrobert.codegpt.settings.hooks.HookManager
|
||||
import ee.carlrobert.codegpt.tokens.truncateToolResult
|
||||
import kotlinx.serialization.SerialName
|
||||
import kotlinx.serialization.Serializable
|
||||
|
||||
class DiagnosticsTool(
|
||||
private val project: Project,
|
||||
private val sessionId: String,
|
||||
private val hookManager: HookManager,
|
||||
) : BaseTool<DiagnosticsTool.Args, DiagnosticsTool.Result>(
|
||||
workingDirectory = project.basePath ?: System.getProperty("user.dir"),
|
||||
argsSerializer = Args.serializer(),
|
||||
resultSerializer = Result.serializer(),
|
||||
name = "Diagnostics",
|
||||
description = """
|
||||
Reads the IDE's current diagnostics for a specific file.
|
||||
|
||||
Use this tool when you need compiler or inspection diagnostics for one file.
|
||||
- The file_path parameter must be an absolute path.
|
||||
- filter='errors_only' returns only errors.
|
||||
- filter='all' returns errors, warnings, weak warnings, and info diagnostics.
|
||||
- Results reflect diagnostics currently available in the IDE for that file.
|
||||
""".trimIndent(),
|
||||
argsClass = Args::class,
|
||||
resultClass = Result::class,
|
||||
hookManager = hookManager,
|
||||
sessionId = sessionId,
|
||||
) {
|
||||
|
||||
@Serializable
|
||||
data class Args(
|
||||
@property:LLMDescription(
|
||||
"The absolute path to the file to inspect. Must be an absolute path."
|
||||
)
|
||||
@SerialName("file_path")
|
||||
val filePath: String,
|
||||
@property:LLMDescription(
|
||||
"Diagnostics filter: 'errors_only' for errors only, or 'all' for all diagnostics."
|
||||
)
|
||||
val filter: DiagnosticsFilter = DiagnosticsFilter.ERRORS_ONLY
|
||||
)
|
||||
|
||||
@Serializable
|
||||
data class Result(
|
||||
@SerialName("file_path")
|
||||
val filePath: String,
|
||||
val filter: DiagnosticsFilter,
|
||||
@SerialName("diagnostic_count")
|
||||
val diagnosticCount: Int = 0,
|
||||
val output: String = "",
|
||||
val error: String? = null
|
||||
)
|
||||
|
||||
override suspend fun doExecute(args: Args): Result {
|
||||
val settingsService = project.service<ProxyAISettingsService>()
|
||||
val decision = settingsService.evaluateToolPermission(this, args.filePath)
|
||||
if (decision == ToolPermissionPolicy.Decision.DENY) {
|
||||
return Result(
|
||||
filePath = args.filePath,
|
||||
filter = args.filter,
|
||||
error = "Access denied by permissions.deny for Diagnostics"
|
||||
)
|
||||
}
|
||||
if (settingsService.hasAllowRulesForTool("Diagnostics")
|
||||
&& decision != ToolPermissionPolicy.Decision.ALLOW
|
||||
) {
|
||||
return Result(
|
||||
filePath = args.filePath,
|
||||
filter = args.filter,
|
||||
error = "Access denied by permissions.allow for Diagnostics"
|
||||
)
|
||||
}
|
||||
if (settingsService.isPathIgnored(args.filePath)) {
|
||||
return Result(
|
||||
filePath = args.filePath,
|
||||
filter = args.filter,
|
||||
error = "File not found: ${args.filePath}"
|
||||
)
|
||||
}
|
||||
|
||||
val diagnosticsService = project.service<ProjectDiagnosticsService>()
|
||||
val virtualFile = diagnosticsService.findVirtualFile(args.filePath)
|
||||
?: return Result(
|
||||
filePath = args.filePath,
|
||||
filter = args.filter,
|
||||
error = "File not found: ${args.filePath}"
|
||||
)
|
||||
|
||||
val report = diagnosticsService.collect(virtualFile, args.filter)
|
||||
return Result(
|
||||
filePath = args.filePath,
|
||||
filter = args.filter,
|
||||
diagnosticCount = report.diagnosticCount,
|
||||
output = report.content.ifBlank { args.filter.emptyMessage() },
|
||||
error = report.error
|
||||
)
|
||||
}
|
||||
|
||||
override fun createDeniedResult(originalArgs: Args, deniedReason: String): Result {
|
||||
return Result(
|
||||
filePath = originalArgs.filePath,
|
||||
filter = originalArgs.filter,
|
||||
error = deniedReason
|
||||
)
|
||||
}
|
||||
|
||||
override fun encodeResultToString(result: Result): String {
|
||||
if (result.error != null) {
|
||||
return "Failed to read diagnostics for '${result.filePath}': ${result.error}"
|
||||
.truncateToolResult()
|
||||
}
|
||||
return result.output.truncateToolResult()
|
||||
}
|
||||
}
|
||||
|
|
@ -8,21 +8,19 @@ import ai.koog.agents.features.eventHandler.feature.handleEvents
|
|||
import ai.koog.agents.features.tokenizer.feature.MessageTokenizer
|
||||
import ai.koog.prompt.dsl.prompt
|
||||
import ai.koog.prompt.tokenizer.Tokenizer
|
||||
import com.intellij.openapi.components.service
|
||||
import com.intellij.openapi.project.Project
|
||||
import ee.carlrobert.codegpt.EncodingManager
|
||||
import ee.carlrobert.codegpt.agent.AgentEvents
|
||||
import ee.carlrobert.codegpt.agent.MessageWithContext
|
||||
import ee.carlrobert.codegpt.agent.clients.shouldStream
|
||||
import ee.carlrobert.codegpt.agent.clients.shouldStreamCustomOpenAI
|
||||
import ee.carlrobert.codegpt.agent.strategy.CODE_AGENT_COMPRESSION
|
||||
import ee.carlrobert.codegpt.agent.strategy.HistoryCompressionConfig
|
||||
import ee.carlrobert.codegpt.agent.strategy.SingleRunStrategyProvider
|
||||
import ee.carlrobert.codegpt.mcp.McpTool
|
||||
import ee.carlrobert.codegpt.mcp.McpToolAliasResolver
|
||||
import ee.carlrobert.codegpt.mcp.McpToolCallHandler
|
||||
import ee.carlrobert.codegpt.settings.models.ModelSettings
|
||||
import ee.carlrobert.codegpt.settings.service.ServiceType
|
||||
import ee.carlrobert.codegpt.settings.service.custom.CustomServicesSettings
|
||||
import ee.carlrobert.codegpt.util.ReasoningFrameTextAdapter
|
||||
import kotlinx.coroutines.*
|
||||
import java.util.*
|
||||
|
|
@ -51,7 +49,7 @@ internal object AgentCompletionRunner : CompletionRunner {
|
|||
val jobRef = AtomicReference<Job?>()
|
||||
val messageBuilder = StringBuilder()
|
||||
val project = request.callParameters.project!!
|
||||
val stream = shouldStreamAgentToolLoop(project, request)
|
||||
val stream = shouldStreamAgentToolLoop(request)
|
||||
val toolCallHandler = project.let { McpToolCallHandler.getInstance(it) }
|
||||
val toolRegistry = createChatToolRegistry(
|
||||
callParameters = request.callParameters,
|
||||
|
|
@ -176,21 +174,13 @@ internal object AgentCompletionRunner : CompletionRunner {
|
|||
}
|
||||
|
||||
internal fun shouldStreamAgentToolLoop(
|
||||
project: Project,
|
||||
request: CompletionRunnerRequest.Chat,
|
||||
): Boolean {
|
||||
val provider = request.serviceType
|
||||
return when (provider) {
|
||||
ServiceType.CUSTOM_OPENAI -> {
|
||||
val selectedServiceId = project.service<ModelSettings>()
|
||||
.getStoredModelForFeature(request.callParameters.featureType)
|
||||
project.service<CustomServicesSettings>().state.services
|
||||
.firstOrNull { it.id == selectedServiceId }
|
||||
?.chatCompletionSettings
|
||||
?.shouldStream()
|
||||
?: false
|
||||
}
|
||||
|
||||
ServiceType.CUSTOM_OPENAI -> shouldStreamCustomOpenAI(
|
||||
request.callParameters.featureType
|
||||
)
|
||||
ServiceType.GOOGLE -> false
|
||||
else -> true
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,231 @@
|
|||
package ee.carlrobert.codegpt.diagnostics
|
||||
|
||||
import com.intellij.codeInsight.daemon.impl.DaemonCodeAnalyzerImpl
|
||||
import com.intellij.codeInsight.daemon.impl.HighlightInfo
|
||||
import com.intellij.lang.annotation.HighlightSeverity
|
||||
import com.intellij.openapi.application.ApplicationManager
|
||||
import com.intellij.openapi.components.Service
|
||||
import com.intellij.openapi.fileEditor.FileDocumentManager
|
||||
import com.intellij.openapi.project.DumbService
|
||||
import com.intellij.openapi.project.Project
|
||||
import com.intellij.openapi.util.text.StringUtil
|
||||
import com.intellij.openapi.vfs.LocalFileSystem
|
||||
import com.intellij.openapi.vfs.VirtualFile
|
||||
import com.intellij.psi.PsiDocumentManager
|
||||
import com.intellij.psi.PsiManager
|
||||
import kotlinx.serialization.SerialName
|
||||
import kotlinx.serialization.Serializable
|
||||
|
||||
@Serializable
|
||||
enum class DiagnosticsFilter(val displayName: String) {
|
||||
@SerialName("errors_only")
|
||||
ERRORS_ONLY("Errors only"),
|
||||
|
||||
@SerialName("all")
|
||||
ALL("All");
|
||||
|
||||
fun includes(severity: HighlightSeverity): Boolean {
|
||||
return when (this) {
|
||||
ERRORS_ONLY -> severity == HighlightSeverity.ERROR
|
||||
ALL -> severity == HighlightSeverity.ERROR ||
|
||||
severity == HighlightSeverity.WARNING ||
|
||||
severity == HighlightSeverity.WEAK_WARNING ||
|
||||
severity == HighlightSeverity.INFORMATION
|
||||
}
|
||||
}
|
||||
|
||||
fun minimumSeverity(): HighlightSeverity {
|
||||
return when (this) {
|
||||
ERRORS_ONLY -> HighlightSeverity.ERROR
|
||||
ALL -> HighlightSeverity.INFORMATION
|
||||
}
|
||||
}
|
||||
|
||||
fun emptyMessage(): String {
|
||||
return when (this) {
|
||||
ERRORS_ONLY -> "No errors found."
|
||||
ALL -> "No diagnostics found."
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
data class DiagnosticsReport(
|
||||
val filePath: String,
|
||||
val filter: DiagnosticsFilter,
|
||||
val content: String = "",
|
||||
val diagnosticCount: Int = 0,
|
||||
val error: String? = null
|
||||
) {
|
||||
val hasDiagnostics: Boolean
|
||||
get() = error == null && diagnosticCount > 0
|
||||
}
|
||||
|
||||
@Service(Service.Level.PROJECT)
|
||||
class ProjectDiagnosticsService(
|
||||
private val project: Project
|
||||
) {
|
||||
|
||||
fun findVirtualFile(filePath: String): VirtualFile? {
|
||||
val normalizedPath = filePath.replace('\\', '/')
|
||||
val fileSystem = LocalFileSystem.getInstance()
|
||||
return fileSystem.findFileByPath(normalizedPath)
|
||||
?: fileSystem.refreshAndFindFileByPath(normalizedPath)
|
||||
}
|
||||
|
||||
fun collect(
|
||||
virtualFile: VirtualFile,
|
||||
filter: DiagnosticsFilter = DiagnosticsFilter.ALL
|
||||
): DiagnosticsReport {
|
||||
return try {
|
||||
var result = DiagnosticsReport(
|
||||
filePath = virtualFile.path,
|
||||
filter = filter
|
||||
)
|
||||
|
||||
ApplicationManager.getApplication().invokeAndWait {
|
||||
result = ApplicationManager.getApplication().runWriteAction<DiagnosticsReport> {
|
||||
DumbService.getInstance(project).runReadActionInSmartMode<DiagnosticsReport> {
|
||||
val document = FileDocumentManager.getInstance().getDocument(virtualFile)
|
||||
?: return@runReadActionInSmartMode DiagnosticsReport(
|
||||
filePath = virtualFile.path,
|
||||
filter = filter,
|
||||
error = "No document found for file."
|
||||
)
|
||||
|
||||
PsiDocumentManager.getInstance(project).commitDocument(document)
|
||||
|
||||
val psiFile = PsiManager.getInstance(project).findFile(virtualFile)
|
||||
?: return@runReadActionInSmartMode DiagnosticsReport(
|
||||
filePath = virtualFile.path,
|
||||
filter = filter,
|
||||
error = "No PSI file found for: ${virtualFile.path}"
|
||||
)
|
||||
|
||||
val rangeHighlights = DaemonCodeAnalyzerImpl.getHighlights(
|
||||
document,
|
||||
filter.minimumSeverity(),
|
||||
project
|
||||
)
|
||||
|
||||
val fileLevel = getFileLevelHighlights(psiFile)
|
||||
val highlights = (rangeHighlights.asSequence() + fileLevel.asSequence())
|
||||
.filter { filter.includes(it.severity) }
|
||||
.mapNotNull { highlight ->
|
||||
extractMessage(highlight)?.let { message ->
|
||||
DiagnosticEntry(highlight, message)
|
||||
}
|
||||
}
|
||||
.distinctBy { Triple(it.message, it.highlight.startOffset, it.highlight.severity) }
|
||||
.sortedWith(
|
||||
compareBy<DiagnosticEntry>(
|
||||
{ severityOrder(it.highlight.severity) },
|
||||
{ it.highlight.startOffset.coerceAtLeast(0) }
|
||||
)
|
||||
)
|
||||
.toList()
|
||||
|
||||
if (highlights.isEmpty()) {
|
||||
return@runReadActionInSmartMode DiagnosticsReport(
|
||||
filePath = virtualFile.path,
|
||||
filter = filter
|
||||
)
|
||||
}
|
||||
|
||||
val maxItems = 200
|
||||
val overflow = (highlights.size - maxItems).coerceAtLeast(0)
|
||||
val shown = highlights.take(maxItems)
|
||||
|
||||
val content = buildString {
|
||||
append("File: ${virtualFile.name}\n")
|
||||
append("Path: ${virtualFile.path}\n")
|
||||
append("Filter: ${filter.displayName}\n\n")
|
||||
|
||||
shown.forEach { entry ->
|
||||
val info = entry.highlight
|
||||
val startOffset = info.startOffset.coerceIn(0, document.textLength)
|
||||
val lineColText =
|
||||
if (info.startOffset >= 0 && document.textLength > 0) {
|
||||
val line = document.getLineNumber(startOffset) + 1
|
||||
val col =
|
||||
startOffset - document.getLineStartOffset(line - 1) + 1
|
||||
"line $line, col $col"
|
||||
} else {
|
||||
"file-level"
|
||||
}
|
||||
|
||||
append("- [${severityLabel(info.severity)}] $lineColText: ${entry.message}\n")
|
||||
}
|
||||
|
||||
if (overflow > 0) {
|
||||
append("... ($overflow more not shown)\n")
|
||||
}
|
||||
}
|
||||
|
||||
DiagnosticsReport(
|
||||
filePath = virtualFile.path,
|
||||
filter = filter,
|
||||
content = content,
|
||||
diagnosticCount = highlights.size
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
} catch (e: Exception) {
|
||||
DiagnosticsReport(
|
||||
filePath = virtualFile.path,
|
||||
filter = filter,
|
||||
error = "Error retrieving diagnostics: ${e.message}"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
private fun getFileLevelHighlights(psiFile: com.intellij.psi.PsiFile): List<HighlightInfo> {
|
||||
return try {
|
||||
val method = DaemonCodeAnalyzerImpl::class.java.methods.firstOrNull {
|
||||
it.name == "getFileLevelHighlights" && it.parameterCount == 2
|
||||
}
|
||||
if (method != null) {
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
method.invoke(null, project, psiFile) as? List<HighlightInfo> ?: emptyList()
|
||||
} else {
|
||||
emptyList()
|
||||
}
|
||||
} catch (_: Throwable) {
|
||||
emptyList()
|
||||
}
|
||||
}
|
||||
|
||||
private fun extractMessage(info: HighlightInfo): String? {
|
||||
val rawMessage = info.description ?: info.toolTip ?: ""
|
||||
return StringUtil.removeHtmlTags(rawMessage, false)
|
||||
.trim()
|
||||
.takeIf { it.isNotBlank() }
|
||||
}
|
||||
|
||||
private fun severityLabel(severity: HighlightSeverity): String {
|
||||
return when (severity) {
|
||||
HighlightSeverity.ERROR -> "ERROR"
|
||||
HighlightSeverity.WARNING -> "WARNING"
|
||||
HighlightSeverity.WEAK_WARNING -> "WEAK_WARNING"
|
||||
HighlightSeverity.INFORMATION -> "INFO"
|
||||
else -> severity.toString()
|
||||
}
|
||||
}
|
||||
|
||||
private fun severityOrder(severity: HighlightSeverity): Int {
|
||||
return when (severity) {
|
||||
HighlightSeverity.ERROR -> 0
|
||||
HighlightSeverity.WARNING -> 1
|
||||
HighlightSeverity.WEAK_WARNING -> 2
|
||||
HighlightSeverity.INFORMATION -> 3
|
||||
else -> 4
|
||||
}
|
||||
}
|
||||
|
||||
private data class DiagnosticEntry(
|
||||
val highlight: HighlightInfo,
|
||||
val message: String
|
||||
)
|
||||
}
|
||||
|
|
@ -795,13 +795,18 @@ class InlineEditInlay(private var editor: Editor) : Disposable {
|
|||
|
||||
private fun collectDiagnosticsInfo(): String? {
|
||||
val tags: Set<TagDetails> = tagManager.getTags()
|
||||
val diagnosticsTag =
|
||||
tags.firstOrNull { it.selected && it is DiagnosticsTagDetails } as? DiagnosticsTagDetails
|
||||
?: return null
|
||||
val diagnosticsTags = tags
|
||||
.filter { it.selected && it is DiagnosticsTagDetails }
|
||||
.filterIsInstance<DiagnosticsTagDetails>()
|
||||
if (diagnosticsTags.isEmpty()) {
|
||||
return null
|
||||
}
|
||||
|
||||
val processor = TagProcessorFactory.getProcessor(project, diagnosticsTag)
|
||||
val stringBuilder = StringBuilder()
|
||||
processor.process(Message("", ""), stringBuilder)
|
||||
diagnosticsTags.forEach { diagnosticsTag ->
|
||||
val processor = TagProcessorFactory.getProcessor(project, diagnosticsTag)
|
||||
processor.process(Message("", ""), stringBuilder)
|
||||
}
|
||||
return stringBuilder.toString().takeIf { it.isNotBlank() }
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -327,6 +327,8 @@ class CustomServiceForm(
|
|||
.setPreferredSize(Dimension(220, 0))
|
||||
.setAddAction { handleAddAction() }
|
||||
.setRemoveAction { handleRemoveAction() }
|
||||
.setMoveUpAction { handleMoveAction(-1) }
|
||||
.setMoveDownAction { handleMoveAction(1) }
|
||||
.setRemoveActionUpdater {
|
||||
formState.value.services.size > 1
|
||||
}
|
||||
|
|
@ -347,7 +349,6 @@ class CustomServiceForm(
|
|||
handleDuplicateAction()
|
||||
}
|
||||
})
|
||||
.disableUpDownActions()
|
||||
|
||||
private fun handleRemoveAction() {
|
||||
val prevSelectedIndex = customProvidersJBList.selectedIndex
|
||||
|
|
@ -373,6 +374,35 @@ class CustomServiceForm(
|
|||
lastSelectedIndex = -1
|
||||
}
|
||||
|
||||
private fun handleMoveAction(offset: Int) {
|
||||
val selectedIndex = customProvidersJBList.selectedIndex
|
||||
if (selectedIndex == -1) {
|
||||
return
|
||||
}
|
||||
|
||||
val targetIndex = selectedIndex + offset
|
||||
if (targetIndex !in formState.value.services.indices) {
|
||||
return
|
||||
}
|
||||
|
||||
if (lastSelectedIndex != -1 && lastSelectedIndex < formState.value.services.size) {
|
||||
updateStateFromForm(lastSelectedIndex)
|
||||
}
|
||||
|
||||
val movedId = formState.value.services[selectedIndex].id
|
||||
selectedServiceId = movedId
|
||||
pendingSelectedId = movedId
|
||||
|
||||
formState.update { state ->
|
||||
val reordered = state.services.toMutableList()
|
||||
val moved = reordered.removeAt(selectedIndex)
|
||||
reordered.add(targetIndex, moved)
|
||||
state.copy(services = reordered.toImmutableList())
|
||||
}
|
||||
|
||||
lastSelectedIndex = -1
|
||||
}
|
||||
|
||||
private fun handleDuplicateAction() {
|
||||
if (lastSelectedIndex >= 0 && lastSelectedIndex < formState.value.services.size) {
|
||||
updateStateFromForm(lastSelectedIndex)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ import ee.carlrobert.codegpt.settings.service.FeatureType
|
|||
import ee.carlrobert.codegpt.ui.textarea.header.tag.TagManager
|
||||
import ee.carlrobert.codegpt.ui.textarea.lookup.LookupActionItem
|
||||
import ee.carlrobert.codegpt.ui.textarea.lookup.LookupGroupItem
|
||||
import ee.carlrobert.codegpt.ui.textarea.lookup.action.DiagnosticsActionItem
|
||||
import ee.carlrobert.codegpt.ui.textarea.lookup.action.ImageActionItem
|
||||
import ee.carlrobert.codegpt.ui.textarea.lookup.action.WebActionItem
|
||||
import ee.carlrobert.codegpt.ui.textarea.lookup.action.files.IncludeOpenFilesActionItem
|
||||
|
|
@ -44,7 +43,7 @@ class SearchManager(
|
|||
FoldersGroupItem(project, tagManager),
|
||||
if (GitFeatureAvailability.isAvailable) GitGroupItem(project) else null,
|
||||
HistoryGroupItem(),
|
||||
DiagnosticsActionItem(tagManager)
|
||||
DiagnosticsGroupItem(tagManager)
|
||||
).filter { it.enabled }
|
||||
|
||||
private fun getAgentGroups() = listOfNotNull(
|
||||
|
|
@ -52,6 +51,7 @@ class SearchManager(
|
|||
FoldersGroupItem(project, tagManager),
|
||||
if (GitFeatureAvailability.isAvailable) GitGroupItem(project) else null,
|
||||
MCPGroupItem(tagManager, FeatureType.AGENT),
|
||||
DiagnosticsGroupItem(tagManager),
|
||||
ImageActionItem(project, tagManager)
|
||||
).filter { it.enabled }
|
||||
|
||||
|
|
@ -62,7 +62,7 @@ class SearchManager(
|
|||
HistoryGroupItem(),
|
||||
PersonasGroupItem(tagManager),
|
||||
MCPGroupItem(tagManager, featureType ?: FeatureType.CHAT),
|
||||
DiagnosticsActionItem(tagManager),
|
||||
DiagnosticsGroupItem(tagManager),
|
||||
WebActionItem(tagManager),
|
||||
ImageActionItem(project, tagManager)
|
||||
).filter { it.enabled }
|
||||
|
|
|
|||
|
|
@ -1,24 +1,16 @@
|
|||
package ee.carlrobert.codegpt.ui.textarea
|
||||
|
||||
import com.intellij.codeInsight.daemon.impl.DaemonCodeAnalyzerImpl
|
||||
import com.intellij.codeInsight.daemon.impl.HighlightInfo
|
||||
import com.intellij.lang.annotation.HighlightSeverity
|
||||
import com.intellij.openapi.application.ApplicationManager
|
||||
import com.intellij.openapi.application.runReadAction
|
||||
import com.intellij.openapi.components.service
|
||||
import com.intellij.openapi.fileEditor.FileDocumentManager
|
||||
import com.intellij.openapi.progress.ProgressManager
|
||||
import com.intellij.openapi.project.DumbService
|
||||
import com.intellij.openapi.project.Project
|
||||
import com.intellij.openapi.util.text.StringUtil
|
||||
import com.intellij.openapi.vfs.VirtualFile
|
||||
import com.intellij.psi.PsiDocumentManager
|
||||
import com.intellij.psi.PsiManager
|
||||
import ee.carlrobert.codegpt.EncodingManager
|
||||
import ee.carlrobert.codegpt.completions.CompletionRequestUtil
|
||||
import ee.carlrobert.codegpt.conversations.Conversation
|
||||
import ee.carlrobert.codegpt.conversations.ConversationsState
|
||||
import ee.carlrobert.codegpt.conversations.message.Message
|
||||
import ee.carlrobert.codegpt.diagnostics.ProjectDiagnosticsService
|
||||
import ee.carlrobert.codegpt.settings.ProxyAISettingsService
|
||||
import ee.carlrobert.codegpt.ui.textarea.header.tag.*
|
||||
import ee.carlrobert.codegpt.ui.textarea.lookup.action.HistoryActionItem
|
||||
|
|
@ -270,118 +262,17 @@ class DiagnosticsTagProcessor(
|
|||
private val project: Project,
|
||||
private val tagDetails: DiagnosticsTagDetails,
|
||||
) : TagProcessor {
|
||||
private val diagnosticsService = project.service<ProjectDiagnosticsService>()
|
||||
|
||||
override fun process(message: Message, promptBuilder: StringBuilder) {
|
||||
val diagnostics = diagnosticsService.collect(tagDetails.virtualFile, tagDetails.filter)
|
||||
if (diagnostics.content.isBlank() && diagnostics.error == null) {
|
||||
return
|
||||
}
|
||||
|
||||
promptBuilder
|
||||
.append("\n## Current File Problems\n")
|
||||
.append(getDiagnosticsString(project, tagDetails.virtualFile))
|
||||
.append("\n## ${tagDetails.virtualFile.name} Problems (${tagDetails.filter.displayName})\n")
|
||||
.append(diagnostics.error ?: diagnostics.content)
|
||||
.append("\n")
|
||||
}
|
||||
|
||||
private fun getDiagnosticsString(project: Project, virtualFile: VirtualFile): String {
|
||||
return try {
|
||||
var result = ""
|
||||
ApplicationManager.getApplication().invokeAndWait {
|
||||
result = ApplicationManager.getApplication().runWriteAction<String> {
|
||||
DumbService.getInstance(project).runReadActionInSmartMode<String> {
|
||||
val document = FileDocumentManager.getInstance().getDocument(virtualFile)
|
||||
?: return@runReadActionInSmartMode "No document found for file"
|
||||
|
||||
PsiDocumentManager.getInstance(project).commitDocument(document)
|
||||
|
||||
val psiManager = PsiManager.getInstance(project)
|
||||
val psiFile = psiManager.findFile(virtualFile)
|
||||
?: return@runReadActionInSmartMode "No PSI file found for: ${virtualFile.path}"
|
||||
|
||||
val rangeHighlights =
|
||||
DaemonCodeAnalyzerImpl.getHighlights(
|
||||
document,
|
||||
HighlightSeverity.WEAK_WARNING,
|
||||
project
|
||||
)
|
||||
// TODO: Find a better solution
|
||||
val fileLevel: List<HighlightInfo> = try {
|
||||
val method = DaemonCodeAnalyzerImpl::class.java.methods.firstOrNull {
|
||||
it.name == "getFileLevelHighlights" && it.parameterCount == 2
|
||||
}
|
||||
if (method != null) {
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
method.invoke(null, project, psiFile) as? List<HighlightInfo>
|
||||
?: emptyList()
|
||||
} else {
|
||||
emptyList()
|
||||
}
|
||||
} catch (_: Throwable) {
|
||||
emptyList()
|
||||
}
|
||||
|
||||
val highlights = (rangeHighlights.asSequence() + fileLevel.asSequence())
|
||||
.distinctBy { Triple(it.description, it.startOffset, it.severity) }
|
||||
.sortedWith(
|
||||
compareBy<HighlightInfo>(
|
||||
{ severityOrder(it.severity) },
|
||||
{ it.startOffset.coerceAtLeast(0) }
|
||||
)
|
||||
)
|
||||
.toList()
|
||||
|
||||
if (highlights.isEmpty()) {
|
||||
return@runReadActionInSmartMode ""
|
||||
}
|
||||
|
||||
val maxItems = 200
|
||||
val overflow = (highlights.size - maxItems).coerceAtLeast(0)
|
||||
val shown = highlights.take(maxItems)
|
||||
|
||||
buildString {
|
||||
append("File: ${virtualFile.name}\n")
|
||||
append("Path: ${virtualFile.path}\n\n")
|
||||
|
||||
shown.forEach { info ->
|
||||
val startOffset = info.startOffset.coerceIn(0, document.textLength)
|
||||
val lineColText =
|
||||
if (info.startOffset >= 0 && document.textLength > 0) {
|
||||
val line = document.getLineNumber(startOffset) + 1
|
||||
val col =
|
||||
startOffset - document.getLineStartOffset(line - 1) + 1
|
||||
"line $line, col $col"
|
||||
} else {
|
||||
"file-level"
|
||||
}
|
||||
|
||||
val rawMessage = info.description ?: info.toolTip ?: ""
|
||||
val message = StringUtil.removeHtmlTags(rawMessage, false).trim()
|
||||
|
||||
val severityLabel = when (info.severity) {
|
||||
HighlightSeverity.ERROR -> "ERROR"
|
||||
HighlightSeverity.WARNING -> "WARNING"
|
||||
HighlightSeverity.WEAK_WARNING -> "WEAK_WARNING"
|
||||
HighlightSeverity.INFORMATION -> "INFO"
|
||||
else -> info.severity.toString()
|
||||
}
|
||||
|
||||
append("- [$severityLabel] $lineColText: $message\n")
|
||||
}
|
||||
|
||||
if (overflow > 0) {
|
||||
append("... ($overflow more not shown)\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
result
|
||||
} catch (e: Exception) {
|
||||
"Error retrieving diagnostics: ${e.message}"
|
||||
}
|
||||
}
|
||||
|
||||
private fun severityOrder(severity: HighlightSeverity): Int {
|
||||
return when (severity) {
|
||||
HighlightSeverity.ERROR -> 0
|
||||
HighlightSeverity.WARNING -> 1
|
||||
HighlightSeverity.WEAK_WARNING -> 2
|
||||
HighlightSeverity.INFORMATION -> 3
|
||||
else -> 4
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import com.intellij.openapi.editor.SelectionModel
|
|||
import com.intellij.openapi.vfs.VirtualFile
|
||||
import com.intellij.ui.JBColor
|
||||
import ee.carlrobert.codegpt.Icons
|
||||
import ee.carlrobert.codegpt.diagnostics.DiagnosticsFilter
|
||||
import ee.carlrobert.codegpt.mcp.ConnectionStatus
|
||||
import ee.carlrobert.codegpt.mcp.McpResource
|
||||
import ee.carlrobert.codegpt.mcp.McpTool
|
||||
|
|
@ -304,7 +305,9 @@ class CodeAnalyzeTagDetails : TagDetails("Code Analyze", AllIcons.Actions.Depend
|
|||
override fun getTooltipText(): String? = null
|
||||
}
|
||||
|
||||
data class DiagnosticsTagDetails(val virtualFile: VirtualFile) :
|
||||
TagDetails("${virtualFile.name} Problems", AllIcons.General.InspectionsEye) {
|
||||
override fun getTooltipText(): String = virtualFile.path
|
||||
data class DiagnosticsTagDetails(
|
||||
val virtualFile: VirtualFile,
|
||||
val filter: DiagnosticsFilter = DiagnosticsFilter.ALL
|
||||
) : TagDetails("${virtualFile.name} Problems (${filter.displayName})", AllIcons.General.InspectionsEye) {
|
||||
override fun getTooltipText(): String = "${virtualFile.path} (${filter.displayName})"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,41 +0,0 @@
|
|||
package ee.carlrobert.codegpt.ui.textarea.lookup.action
|
||||
|
||||
import com.intellij.icons.AllIcons
|
||||
import com.intellij.openapi.project.Project
|
||||
import com.intellij.openapi.vfs.VirtualFile
|
||||
import ee.carlrobert.codegpt.ui.textarea.UserInputPanel
|
||||
import ee.carlrobert.codegpt.ui.textarea.header.tag.DiagnosticsTagDetails
|
||||
import ee.carlrobert.codegpt.ui.textarea.header.tag.EditorTagDetails
|
||||
import ee.carlrobert.codegpt.ui.textarea.header.tag.FileTagDetails
|
||||
import ee.carlrobert.codegpt.ui.textarea.header.tag.TagManager
|
||||
import ee.carlrobert.codegpt.util.EditorUtil
|
||||
|
||||
class DiagnosticsActionItem(
|
||||
private val tagManager: TagManager
|
||||
) : AbstractLookupActionItem() {
|
||||
|
||||
override val displayName: String = "Diagnostics"
|
||||
override val icon = AllIcons.General.InspectionsEye
|
||||
override val enabled: Boolean
|
||||
get() = tagManager.getTags().none { it is DiagnosticsTagDetails } &&
|
||||
tagManager.getTags().any { it is FileTagDetails || it is EditorTagDetails }
|
||||
|
||||
override fun execute(project: Project, userInputPanel: UserInputPanel) {
|
||||
val virtualFile = findVirtualFile(project)
|
||||
virtualFile?.let { file ->
|
||||
userInputPanel.addTag(DiagnosticsTagDetails(file))
|
||||
}
|
||||
}
|
||||
|
||||
private fun findVirtualFile(project: Project): VirtualFile? {
|
||||
val existingFile = tagManager.getTags()
|
||||
.firstNotNullOfOrNull { tag ->
|
||||
when (tag) {
|
||||
is FileTagDetails -> tag.virtualFile
|
||||
is EditorTagDetails -> tag.virtualFile
|
||||
else -> null
|
||||
}
|
||||
}
|
||||
return existingFile ?: EditorUtil.getSelectedEditor(project)?.virtualFile
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
package ee.carlrobert.codegpt.ui.textarea.lookup.action
|
||||
|
||||
import com.intellij.openapi.vfs.VirtualFile
|
||||
import ee.carlrobert.codegpt.ui.textarea.header.tag.EditorSelectionTagDetails
|
||||
import ee.carlrobert.codegpt.ui.textarea.header.tag.EditorTagDetails
|
||||
import ee.carlrobert.codegpt.ui.textarea.header.tag.FileTagDetails
|
||||
import ee.carlrobert.codegpt.ui.textarea.header.tag.SelectionTagDetails
|
||||
import ee.carlrobert.codegpt.ui.textarea.header.tag.TagDetails
|
||||
|
||||
internal fun selectedContextFiles(tags: Collection<TagDetails>): List<VirtualFile> {
|
||||
return tags.asSequence()
|
||||
.filter { it.selected }
|
||||
.mapNotNull { tag ->
|
||||
when (tag) {
|
||||
is FileTagDetails -> tag.virtualFile
|
||||
is EditorTagDetails -> tag.virtualFile
|
||||
is SelectionTagDetails -> tag.virtualFile
|
||||
is EditorSelectionTagDetails -> tag.virtualFile
|
||||
else -> null
|
||||
}
|
||||
}
|
||||
.distinctBy { it.path }
|
||||
.toList()
|
||||
}
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
package ee.carlrobert.codegpt.ui.textarea.lookup.action
|
||||
|
||||
import com.intellij.icons.AllIcons
|
||||
import com.intellij.notification.NotificationType
|
||||
import com.intellij.openapi.components.service
|
||||
import com.intellij.openapi.project.Project
|
||||
import ee.carlrobert.codegpt.diagnostics.DiagnosticsFilter
|
||||
import ee.carlrobert.codegpt.diagnostics.ProjectDiagnosticsService
|
||||
import ee.carlrobert.codegpt.ui.OverlayUtil
|
||||
import ee.carlrobert.codegpt.ui.textarea.UserInputPanel
|
||||
import ee.carlrobert.codegpt.ui.textarea.header.tag.DiagnosticsTagDetails
|
||||
import ee.carlrobert.codegpt.ui.textarea.header.tag.TagManager
|
||||
|
||||
class DiagnosticsFilterActionItem(
|
||||
private val tagManager: TagManager,
|
||||
private val filter: DiagnosticsFilter
|
||||
) : AbstractLookupActionItem() {
|
||||
|
||||
override val displayName: String = filter.displayName
|
||||
override val icon = AllIcons.General.InspectionsEye
|
||||
|
||||
override fun execute(project: Project, userInputPanel: UserInputPanel) {
|
||||
val diagnosticsService = project.service<ProjectDiagnosticsService>()
|
||||
val files = selectedContextFiles(userInputPanel.getSelectedTags())
|
||||
var matched = false
|
||||
|
||||
files.forEach { virtualFile ->
|
||||
val diagnostics = diagnosticsService.collect(virtualFile, filter)
|
||||
if (!diagnostics.hasDiagnostics) {
|
||||
return@forEach
|
||||
}
|
||||
matched = true
|
||||
|
||||
val newTag = DiagnosticsTagDetails(virtualFile, filter)
|
||||
val existing = tagManager.getTags()
|
||||
.filterIsInstance<DiagnosticsTagDetails>()
|
||||
.firstOrNull { it.virtualFile == virtualFile }
|
||||
|
||||
when {
|
||||
existing == null -> {
|
||||
userInputPanel.addTag(newTag)
|
||||
}
|
||||
|
||||
existing != newTag -> {
|
||||
tagManager.updateTag(existing, newTag)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!matched) {
|
||||
OverlayUtil.showNotification(
|
||||
filter.emptyMessage().removeSuffix(".") + " in selected context files.",
|
||||
NotificationType.INFORMATION
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
package ee.carlrobert.codegpt.ui.textarea.lookup.group
|
||||
|
||||
import com.intellij.icons.AllIcons
|
||||
import ee.carlrobert.codegpt.diagnostics.DiagnosticsFilter
|
||||
import ee.carlrobert.codegpt.ui.textarea.header.tag.TagManager
|
||||
import ee.carlrobert.codegpt.ui.textarea.lookup.LookupActionItem
|
||||
import ee.carlrobert.codegpt.ui.textarea.lookup.action.DiagnosticsFilterActionItem
|
||||
import ee.carlrobert.codegpt.ui.textarea.lookup.action.selectedContextFiles
|
||||
|
||||
class DiagnosticsGroupItem(
|
||||
private val tagManager: TagManager
|
||||
) : AbstractLookupGroupItem() {
|
||||
|
||||
override val displayName: String = "Diagnostics"
|
||||
override val icon = AllIcons.General.InspectionsEye
|
||||
override val enabled: Boolean
|
||||
get() = selectedContextFiles(tagManager.getTags()).isNotEmpty()
|
||||
|
||||
override suspend fun getLookupItems(searchText: String): List<LookupActionItem> {
|
||||
return listOf(
|
||||
DiagnosticsFilterActionItem(tagManager, DiagnosticsFilter.ERRORS_ONLY),
|
||||
DiagnosticsFilterActionItem(tagManager, DiagnosticsFilter.ALL)
|
||||
).filter {
|
||||
searchText.isEmpty() || it.displayName.contains(searchText, ignoreCase = true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -4,6 +4,7 @@ import ai.koog.agents.snapshot.providers.file.JVMFilePersistenceStorageProvider
|
|||
import com.intellij.openapi.components.service
|
||||
import ee.carlrobert.codegpt.credentials.CredentialsStore.CredentialKey.*
|
||||
import ee.carlrobert.codegpt.credentials.CredentialsStore.setCredential
|
||||
import ee.carlrobert.codegpt.agent.clients.shouldStreamCustomOpenAI
|
||||
import ee.carlrobert.codegpt.settings.models.ModelCatalog
|
||||
import ee.carlrobert.codegpt.settings.models.ModelSettings
|
||||
import ee.carlrobert.codegpt.settings.service.FeatureType
|
||||
|
|
@ -174,6 +175,144 @@ class AgentProviderIntegrationTest : IntegrationTest() {
|
|||
assertThat(result.events.text.toString()).isEqualTo("Hello from Custom OpenAI")
|
||||
}
|
||||
|
||||
fun testCustomOpenAIAgentStreamsWhenStoredSelectionUsesModelId() {
|
||||
val customService = configureCustomOpenAIService(stream = true)
|
||||
service<ModelSettings>().setModel(
|
||||
FeatureType.AGENT,
|
||||
"custom-agent-model",
|
||||
ServiceType.CUSTOM_OPENAI
|
||||
)
|
||||
assertThat(shouldStreamCustomOpenAI(FeatureType.AGENT)).isTrue()
|
||||
expectCustomOpenAI(StreamHttpExchange { request ->
|
||||
assertThat(request.uri.path).isEqualTo("/v1/chat/completions")
|
||||
assertThat(request.method).isEqualTo("POST")
|
||||
assertThat(request.body["stream"]).isEqualTo(true)
|
||||
assertThat(extractPromptText(request)).contains("Say hello from streamed Custom OpenAI")
|
||||
chatCompletionChunks("custom-agent-model", "Hello from streamed Custom OpenAI")
|
||||
})
|
||||
|
||||
val result = runAgent(ServiceType.CUSTOM_OPENAI, "Say hello from streamed Custom OpenAI")
|
||||
|
||||
assertThat(customService.id).isNotBlank()
|
||||
assertThat(result.output).isEqualTo("Hello from streamed Custom OpenAI")
|
||||
assertThat(result.events.text.toString()).isEqualTo("Hello from streamed Custom OpenAI")
|
||||
}
|
||||
|
||||
fun testCustomOpenAIResponsesAgentStreamsWhenStoredSelectionUsesModelId() {
|
||||
val customService = configureCustomOpenAIService(
|
||||
path = "/v1/responses",
|
||||
model = "custom-responses-model",
|
||||
stream = true,
|
||||
useResponsesApiBody = true
|
||||
)
|
||||
service<ModelSettings>().setModel(
|
||||
FeatureType.AGENT,
|
||||
"custom-responses-model",
|
||||
ServiceType.CUSTOM_OPENAI
|
||||
)
|
||||
assertThat(shouldStreamCustomOpenAI(FeatureType.AGENT)).isTrue()
|
||||
expectCustomOpenAI(StreamHttpExchange { request ->
|
||||
assertThat(request.uri.path).isEqualTo("/v1/responses")
|
||||
assertThat(request.method).isEqualTo("POST")
|
||||
assertThat(request.body["stream"]).isEqualTo(true)
|
||||
assertThat(extractPromptText(request)).contains("Say hello from Custom OpenAI Responses")
|
||||
openAiResponsesChunks(
|
||||
model = "custom-responses-model",
|
||||
text = "Hello from Custom OpenAI Responses"
|
||||
)
|
||||
})
|
||||
|
||||
val result = runAgent(ServiceType.CUSTOM_OPENAI, "Say hello from Custom OpenAI Responses")
|
||||
|
||||
assertThat(customService.id).isNotBlank()
|
||||
assertThat(result.output).isEqualTo("Hello from Custom OpenAI Responses")
|
||||
assertThat(result.events.text.toString()).isEqualTo("Hello from Custom OpenAI Responses")
|
||||
}
|
||||
|
||||
fun testCustomOpenAIResponsesAgentSerializesToolHistoryAsResponsesInput() {
|
||||
val fixture = createReadFixture("Custom Responses fixture")
|
||||
configureCustomOpenAIService(
|
||||
path = "/v1/responses",
|
||||
model = "custom-responses-model",
|
||||
stream = false,
|
||||
useResponsesApiBody = true
|
||||
)
|
||||
expectCustomOpenAI(BasicHttpExchange { request ->
|
||||
assertThat(request.uri.path).isEqualTo("/v1/responses")
|
||||
assertThat(request.method).isEqualTo("POST")
|
||||
assertThat(extractPromptText(request)).contains("Read the fixture and repeat its contents")
|
||||
ResponseEntity(
|
||||
openAiResponsesToolCallResponse(
|
||||
model = "custom-responses-model",
|
||||
toolName = "Read",
|
||||
callId = "call_custom_responses_read",
|
||||
arguments = """{"file_path":"${fixture.path}"}"""
|
||||
)
|
||||
)
|
||||
})
|
||||
expectCustomOpenAI(BasicHttpExchange { request ->
|
||||
assertThat(request.uri.path).isEqualTo("/v1/responses")
|
||||
assertThat(request.method).isEqualTo("POST")
|
||||
assertThatResponsesToolHistory(
|
||||
request = request,
|
||||
toolName = "Read",
|
||||
callId = "call_custom_responses_read",
|
||||
fixture = fixture
|
||||
)
|
||||
ResponseEntity(
|
||||
openAiResponsesResponse(
|
||||
model = "custom-responses-model",
|
||||
text = "Custom Responses read: ${fixture.contents}"
|
||||
)
|
||||
)
|
||||
})
|
||||
|
||||
val result = runAgent(ServiceType.CUSTOM_OPENAI, "Read the fixture and repeat its contents")
|
||||
|
||||
assertThat(result.output).isEqualTo("Custom Responses read: ${fixture.contents}")
|
||||
assertThat(result.events.text.toString()).isEqualTo("Custom Responses read: ${fixture.contents}")
|
||||
}
|
||||
|
||||
fun testCustomOpenAIResponsesStreamingAgentCompletesToolLoop() {
|
||||
val fixture = createReadFixture("Custom Responses streaming fixture")
|
||||
configureCustomOpenAIService(
|
||||
path = "/v1/responses",
|
||||
model = "custom-responses-model",
|
||||
stream = true,
|
||||
useResponsesApiBody = true
|
||||
)
|
||||
expectCustomOpenAI(StreamHttpExchange { request ->
|
||||
assertThat(request.uri.path).isEqualTo("/v1/responses")
|
||||
assertThat(request.method).isEqualTo("POST")
|
||||
assertThat(extractPromptText(request)).contains("Read the fixture and repeat its contents")
|
||||
openAiResponsesToolCallChunks(
|
||||
model = "custom-responses-model",
|
||||
toolName = "Read",
|
||||
callId = "call_custom_responses_stream_read",
|
||||
arguments = """{"file_path":"${fixture.path}"}"""
|
||||
)
|
||||
})
|
||||
expectCustomOpenAI(StreamHttpExchange { request ->
|
||||
assertThat(request.uri.path).isEqualTo("/v1/responses")
|
||||
assertThat(request.method).isEqualTo("POST")
|
||||
assertThatResponsesToolHistory(
|
||||
request = request,
|
||||
toolName = "Read",
|
||||
callId = "call_custom_responses_stream_read",
|
||||
fixture = fixture
|
||||
)
|
||||
openAiResponsesChunks(
|
||||
model = "custom-responses-model",
|
||||
text = "Custom Responses streaming read: ${fixture.contents}"
|
||||
)
|
||||
})
|
||||
|
||||
val result = runAgent(ServiceType.CUSTOM_OPENAI, "Read the fixture and repeat its contents")
|
||||
|
||||
assertThat(result.output).isEqualTo("Custom Responses streaming read: ${fixture.contents}")
|
||||
assertThat(result.events.text.toString()).isEqualTo("Custom Responses streaming read: ${fixture.contents}")
|
||||
}
|
||||
|
||||
private fun runAgent(
|
||||
provider: ServiceType,
|
||||
userMessage: String
|
||||
|
|
@ -246,9 +385,10 @@ class AgentProviderIntegrationTest : IntegrationTest() {
|
|||
)
|
||||
}
|
||||
|
||||
private fun openAiResponsesChunks(): List<String> {
|
||||
val model = "gpt-4.1-mini"
|
||||
val text = "Hello from OpenAI"
|
||||
private fun openAiResponsesChunks(
|
||||
model: String = "gpt-4.1-mini",
|
||||
text: String = "Hello from OpenAI"
|
||||
): List<String> {
|
||||
val chunks = text.chunked(4)
|
||||
return chunks.mapIndexed { index, chunk ->
|
||||
jsonMapResponse(
|
||||
|
|
@ -313,6 +453,154 @@ class AgentProviderIntegrationTest : IntegrationTest() {
|
|||
)
|
||||
}
|
||||
|
||||
private fun openAiResponsesToolCallChunks(
|
||||
model: String,
|
||||
toolName: String,
|
||||
callId: String,
|
||||
arguments: String
|
||||
): List<String> {
|
||||
return listOf(
|
||||
jsonMapResponse(
|
||||
e("type", "response.output_item.added"),
|
||||
e(
|
||||
"item",
|
||||
jsonMap(
|
||||
e("type", "function_call"),
|
||||
e("id", "fc_1"),
|
||||
e("call_id", callId),
|
||||
e("name", toolName),
|
||||
e("arguments", "")
|
||||
)
|
||||
),
|
||||
e("output_index", 0),
|
||||
e("sequence_number", 1)
|
||||
),
|
||||
jsonMapResponse(
|
||||
e("type", "response.function_call_arguments.delta"),
|
||||
e("item_id", "fc_1"),
|
||||
e("output_index", 0),
|
||||
e("delta", arguments),
|
||||
e("call_id", callId),
|
||||
e("sequence_number", 2)
|
||||
),
|
||||
jsonMapResponse(
|
||||
e("type", "response.completed"),
|
||||
e(
|
||||
"response",
|
||||
jsonMap(
|
||||
e("id", "resp-tool-call"),
|
||||
e("object", "response"),
|
||||
e("created_at", 1),
|
||||
e("model", model),
|
||||
e(
|
||||
"output",
|
||||
jsonArray(
|
||||
jsonMap(
|
||||
e("type", "function_call"),
|
||||
e("id", "fc_1"),
|
||||
e("call_id", callId),
|
||||
e("name", toolName),
|
||||
e("arguments", arguments),
|
||||
e("status", "completed")
|
||||
)
|
||||
)
|
||||
),
|
||||
e("parallel_tool_calls", true),
|
||||
e("status", "completed"),
|
||||
e("text", jsonMap())
|
||||
)
|
||||
),
|
||||
e("sequence_number", 3)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
private fun openAiResponsesToolCallResponse(
|
||||
model: String,
|
||||
toolName: String,
|
||||
callId: String,
|
||||
arguments: String
|
||||
): String {
|
||||
return jsonMapResponse(
|
||||
e("id", "resp-tool-call"),
|
||||
e("object", "response"),
|
||||
e("created_at", 1),
|
||||
e("model", model),
|
||||
e(
|
||||
"output",
|
||||
jsonArray(
|
||||
jsonMap(
|
||||
e("type", "function_call"),
|
||||
e("id", "fc_1"),
|
||||
e("call_id", callId),
|
||||
e("name", toolName),
|
||||
e("arguments", arguments),
|
||||
e("status", "completed")
|
||||
)
|
||||
)
|
||||
),
|
||||
e("parallel_tool_calls", true),
|
||||
e("status", "completed"),
|
||||
e("text", jsonMap()),
|
||||
e(
|
||||
"usage",
|
||||
jsonMap(
|
||||
e("input_tokens", 1),
|
||||
e("input_tokens_details", jsonMap("cached_tokens", 0)),
|
||||
e("output_tokens", 1),
|
||||
e("output_tokens_details", jsonMap("reasoning_tokens", 0)),
|
||||
e("total_tokens", 2)
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
private fun openAiResponsesResponse(
|
||||
model: String,
|
||||
text: String
|
||||
): String {
|
||||
return jsonMapResponse(
|
||||
e("id", "resp-openai-test"),
|
||||
e("object", "response"),
|
||||
e("created_at", 1),
|
||||
e("model", model),
|
||||
e(
|
||||
"output",
|
||||
jsonArray(
|
||||
jsonMap(
|
||||
e("type", "message"),
|
||||
e("id", "msg_1"),
|
||||
e("role", "assistant"),
|
||||
e("status", "completed"),
|
||||
e(
|
||||
"content",
|
||||
jsonArray(
|
||||
jsonMap(
|
||||
e("type", "output_text"),
|
||||
e("text", text),
|
||||
e("annotations", jsonArray())
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
),
|
||||
e("parallel_tool_calls", true),
|
||||
e("status", "completed"),
|
||||
e("text", jsonMap()),
|
||||
e(
|
||||
"usage",
|
||||
jsonMap(
|
||||
e("input_tokens", 1),
|
||||
e("input_tokens_details", jsonMap("cached_tokens", 0)),
|
||||
e("output_tokens", 1),
|
||||
e("output_tokens_details", jsonMap("reasoning_tokens", 0)),
|
||||
e("total_tokens", 2)
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
private fun anthropicTextChunks(text: String): List<String> {
|
||||
return listOf(
|
||||
jsonMapResponse(
|
||||
|
|
@ -573,15 +861,24 @@ class AgentProviderIntegrationTest : IntegrationTest() {
|
|||
)
|
||||
}
|
||||
|
||||
private fun configureCustomOpenAIService(): CustomServiceSettingsState {
|
||||
private fun configureCustomOpenAIService(
|
||||
path: String = "/v1/chat/completions",
|
||||
model: String = "custom-agent-model",
|
||||
stream: Boolean = false,
|
||||
useResponsesApiBody: Boolean = false
|
||||
): CustomServiceSettingsState {
|
||||
val settings = service<CustomServicesSettings>()
|
||||
val serviceState = CustomServiceSettingsState().apply {
|
||||
name = "Agent Test Custom Service"
|
||||
chatCompletionSettings.url =
|
||||
System.getProperty("customOpenAI.baseUrl") + "/v1/chat/completions"
|
||||
System.getProperty("customOpenAI.baseUrl") + path
|
||||
chatCompletionSettings.headers.clear()
|
||||
chatCompletionSettings.body.clear()
|
||||
chatCompletionSettings.body["model"] = "custom-agent-model"
|
||||
chatCompletionSettings.body["model"] = model
|
||||
chatCompletionSettings.body["stream"] = stream
|
||||
if (useResponsesApiBody) {
|
||||
chatCompletionSettings.body["input"] = "\$OPENAI_MESSAGES"
|
||||
}
|
||||
}
|
||||
settings.state.services.clear()
|
||||
settings.state.services.add(serviceState)
|
||||
|
|
@ -616,6 +913,28 @@ class AgentProviderIntegrationTest : IntegrationTest() {
|
|||
}
|
||||
}
|
||||
|
||||
private fun assertThatResponsesToolHistory(
|
||||
request: RequestEntity,
|
||||
toolName: String,
|
||||
callId: String,
|
||||
fixture: ReadFixture
|
||||
) {
|
||||
val input = request.body["input"] as? List<*> ?: error("Expected responses input list")
|
||||
val items = input.mapNotNull { it as? Map<*, *> }
|
||||
|
||||
assertThat(items.any { item ->
|
||||
item["type"] == "function_call" &&
|
||||
item["name"] == toolName &&
|
||||
item["call_id"] == callId &&
|
||||
item["arguments"] == """{"file_path":"${fixture.path}"}"""
|
||||
}).isTrue()
|
||||
assertThat(items.any { item ->
|
||||
item["type"] == "function_call_output" &&
|
||||
item["call_id"] == callId &&
|
||||
(item["output"] as? String).orEmpty().contains(fixture.contents)
|
||||
}).isTrue()
|
||||
}
|
||||
|
||||
private fun extractGooglePromptText(request: RequestEntity): String {
|
||||
val contents = request.body["contents"] as? List<*> ?: return ""
|
||||
return contents.joinToString("\n") { content ->
|
||||
|
|
|
|||
|
|
@ -0,0 +1,26 @@
|
|||
package ee.carlrobert.codegpt.agent
|
||||
|
||||
import ee.carlrobert.codegpt.agent.tools.DiagnosticsTool
|
||||
import ee.carlrobert.codegpt.settings.hooks.HookManager
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import testsupport.IntegrationTest
|
||||
import java.io.File
|
||||
|
||||
class DiagnosticsToolTest : IntegrationTest() {
|
||||
|
||||
fun `test diagnostics tool should be registered in tool specs`() {
|
||||
assertThat(ToolSpecs.find("Diagnostics")).isNotNull()
|
||||
}
|
||||
|
||||
fun `test diagnostics tool should return missing file error`() {
|
||||
val missingPath = File(project.basePath, "does-not-exist.txt").absolutePath
|
||||
val tool = DiagnosticsTool(project, "test-session-id", HookManager(project))
|
||||
|
||||
val result = runBlocking {
|
||||
tool.execute(DiagnosticsTool.Args(missingPath))
|
||||
}
|
||||
|
||||
assertThat(result.error).isEqualTo("File not found: $missingPath")
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,219 @@
|
|||
package ee.carlrobert.codegpt.agent.clients
|
||||
|
||||
import ai.koog.prompt.executor.clients.openai.base.models.Content
|
||||
import ai.koog.prompt.executor.clients.openai.base.models.OpenAIFunction
|
||||
import ai.koog.prompt.executor.clients.openai.base.models.OpenAIMessage
|
||||
import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolCall
|
||||
import ai.koog.prompt.executor.clients.openai.base.models.OpenAITool
|
||||
import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolChoice
|
||||
import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolFunction
|
||||
import ai.koog.prompt.llm.LLModel
|
||||
import ai.koog.prompt.params.LLMParams
|
||||
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceChatCompletionSettingsState
|
||||
import kotlinx.serialization.json.Json
|
||||
import kotlinx.serialization.json.buildJsonArray
|
||||
import kotlinx.serialization.json.buildJsonObject
|
||||
import kotlinx.serialization.json.boolean
|
||||
import kotlinx.serialization.json.jsonArray
|
||||
import kotlinx.serialization.json.jsonObject
|
||||
import kotlinx.serialization.json.jsonPrimitive
|
||||
import kotlinx.serialization.json.put
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.junit.Test
|
||||
|
||||
class CustomOpenAIResponsesApiSerializationTest {
|
||||
|
||||
private val json = Json {
|
||||
ignoreUnknownKeys = true
|
||||
encodeDefaults = true
|
||||
explicitNulls = false
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `responses api request should encode tools with top level name`() {
|
||||
val state = CustomServiceChatCompletionSettingsState().apply {
|
||||
url = "https://example.com/v1/responses"
|
||||
body.clear()
|
||||
body["stream"] = true
|
||||
body["input"] = "\$OPENAI_MESSAGES"
|
||||
}
|
||||
val client = CustomOpenAILLMClient.fromSettingsState("test-key", state)
|
||||
val serializeMethod = client.javaClass.getDeclaredMethod(
|
||||
"serializeProviderChatRequest",
|
||||
List::class.java,
|
||||
LLModel::class.java,
|
||||
List::class.java,
|
||||
OpenAIToolChoice::class.java,
|
||||
LLMParams::class.java,
|
||||
Boolean::class.javaPrimitiveType
|
||||
)
|
||||
serializeMethod.isAccessible = true
|
||||
|
||||
val payload = serializeMethod.invoke(
|
||||
client,
|
||||
listOf(OpenAIMessage.User(content = Content.Text("hello"))),
|
||||
LLModel(
|
||||
id = "gpt-test",
|
||||
provider = CustomOpenAILLMClient.CustomOpenAI,
|
||||
capabilities = emptyList(),
|
||||
contextLength = 128_000,
|
||||
maxOutputTokens = 4_096
|
||||
),
|
||||
listOf(
|
||||
OpenAITool(
|
||||
OpenAIToolFunction(
|
||||
name = "Diagnostics",
|
||||
description = "Read IDE diagnostics",
|
||||
parameters = buildJsonObject {
|
||||
put("type", "object")
|
||||
put("properties", buildJsonObject {})
|
||||
put("required", buildJsonArray {})
|
||||
},
|
||||
strict = true
|
||||
)
|
||||
)
|
||||
),
|
||||
OpenAIToolChoice.Function(OpenAIToolChoice.FunctionName("Diagnostics")),
|
||||
CustomOpenAIParams(),
|
||||
true
|
||||
) as String
|
||||
|
||||
val request = json.parseToJsonElement(payload).jsonObject
|
||||
val tool = request.getValue("tools").jsonArray.single().jsonObject
|
||||
val toolChoice = request.getValue("tool_choice").jsonObject
|
||||
|
||||
assertThat(tool.getValue("type").jsonPrimitive.content).isEqualTo("function")
|
||||
assertThat(tool.getValue("name").jsonPrimitive.content).isEqualTo("Diagnostics")
|
||||
assertThat(tool.getValue("description").jsonPrimitive.content).isEqualTo("Read IDE diagnostics")
|
||||
assertThat(tool.getValue("strict").jsonPrimitive.boolean).isTrue()
|
||||
assertThat(tool).doesNotContainKey("function")
|
||||
assertThat(toolChoice.getValue("type").jsonPrimitive.content).isEqualTo("function")
|
||||
assertThat(toolChoice.getValue("name").jsonPrimitive.content).isEqualTo("Diagnostics")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `responses api request should encode mode tool choice as lowercase string`() {
|
||||
val state = CustomServiceChatCompletionSettingsState().apply {
|
||||
url = "https://example.com/v1/responses"
|
||||
body.clear()
|
||||
body["stream"] = true
|
||||
body["input"] = "\$OPENAI_MESSAGES"
|
||||
}
|
||||
val client = CustomOpenAILLMClient.fromSettingsState("test-key", state)
|
||||
val serializeMethod = client.javaClass.getDeclaredMethod(
|
||||
"serializeProviderChatRequest",
|
||||
List::class.java,
|
||||
LLModel::class.java,
|
||||
List::class.java,
|
||||
OpenAIToolChoice::class.java,
|
||||
LLMParams::class.java,
|
||||
Boolean::class.javaPrimitiveType
|
||||
)
|
||||
serializeMethod.isAccessible = true
|
||||
|
||||
val payload = serializeMethod.invoke(
|
||||
client,
|
||||
listOf(OpenAIMessage.User(content = Content.Text("hello"))),
|
||||
LLModel(
|
||||
id = "gpt-test",
|
||||
provider = CustomOpenAILLMClient.CustomOpenAI,
|
||||
capabilities = emptyList(),
|
||||
contextLength = 128_000,
|
||||
maxOutputTokens = 4_096
|
||||
),
|
||||
emptyList<OpenAITool>(),
|
||||
openAIToolChoiceMode("required"),
|
||||
CustomOpenAIParams(),
|
||||
true
|
||||
) as String
|
||||
|
||||
val request = json.parseToJsonElement(payload).jsonObject
|
||||
|
||||
assertThat(request.getValue("tool_choice").jsonPrimitive.content).isEqualTo("required")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `responses api request should encode agent tool history as input items`() {
|
||||
val state = CustomServiceChatCompletionSettingsState().apply {
|
||||
url = "https://example.com/v1/responses"
|
||||
body.clear()
|
||||
body["stream"] = true
|
||||
body["input"] = "\$OPENAI_MESSAGES"
|
||||
}
|
||||
val client = CustomOpenAILLMClient.fromSettingsState("test-key", state)
|
||||
val serializeMethod = client.javaClass.getDeclaredMethod(
|
||||
"serializeProviderChatRequest",
|
||||
List::class.java,
|
||||
LLModel::class.java,
|
||||
List::class.java,
|
||||
OpenAIToolChoice::class.java,
|
||||
LLMParams::class.java,
|
||||
Boolean::class.javaPrimitiveType
|
||||
)
|
||||
serializeMethod.isAccessible = true
|
||||
|
||||
val payload = serializeMethod.invoke(
|
||||
client,
|
||||
listOf(
|
||||
OpenAIMessage.System(content = Content.Text("System instructions")),
|
||||
OpenAIMessage.User(content = Content.Text("Read the fixture")),
|
||||
OpenAIMessage.Assistant(
|
||||
content = Content.Text("Calling Read"),
|
||||
toolCalls = listOf(
|
||||
OpenAIToolCall(
|
||||
"call_read",
|
||||
OpenAIFunction("Read", "\"{\\\"file_path\\\":\\\"/tmp/fixture.txt\\\"}\"")
|
||||
)
|
||||
)
|
||||
),
|
||||
OpenAIMessage.Tool(
|
||||
content = Content.Text("1\tfixture contents"),
|
||||
toolCallId = "call_read"
|
||||
)
|
||||
),
|
||||
LLModel(
|
||||
id = "gpt-test",
|
||||
provider = CustomOpenAILLMClient.CustomOpenAI,
|
||||
capabilities = emptyList(),
|
||||
contextLength = 128_000,
|
||||
maxOutputTokens = 4_096
|
||||
),
|
||||
emptyList<OpenAITool>(),
|
||||
null,
|
||||
CustomOpenAIParams(),
|
||||
true
|
||||
) as String
|
||||
|
||||
val request = json.parseToJsonElement(payload).jsonObject
|
||||
val input = request.getValue("input").jsonArray
|
||||
|
||||
assertThat(input.map { it.jsonObject.getValue("type").jsonPrimitive.content }).containsExactly(
|
||||
"message",
|
||||
"message",
|
||||
"message",
|
||||
"function_call",
|
||||
"function_call_output"
|
||||
)
|
||||
assertThat(input[0].jsonObject.getValue("role").jsonPrimitive.content).isEqualTo("developer")
|
||||
assertThat(input[1].jsonObject.getValue("role").jsonPrimitive.content).isEqualTo("user")
|
||||
assertThat(input[2].jsonObject.getValue("role").jsonPrimitive.content).isEqualTo("assistant")
|
||||
assertThat(
|
||||
input[2].jsonObject.getValue("content").jsonArray.single().jsonObject.getValue("text").jsonPrimitive.content
|
||||
).isEqualTo("Calling Read")
|
||||
assertThat(input[3].jsonObject.getValue("name").jsonPrimitive.content).isEqualTo("Read")
|
||||
assertThat(input[3].jsonObject.getValue("call_id").jsonPrimitive.content).isEqualTo("call_read")
|
||||
assertThat(input[3].jsonObject.getValue("arguments").jsonPrimitive.content)
|
||||
.isEqualTo("""{"file_path":"/tmp/fixture.txt"}""")
|
||||
assertThat(input[4].jsonObject.getValue("call_id").jsonPrimitive.content).isEqualTo("call_read")
|
||||
assertThat(input[4].jsonObject.getValue("output").jsonPrimitive.content)
|
||||
.isEqualTo("1\tfixture contents")
|
||||
}
|
||||
|
||||
private fun openAIToolChoiceMode(value: String): OpenAIToolChoice {
|
||||
val modeClass = Class.forName(
|
||||
"ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolChoice\$Mode"
|
||||
)
|
||||
val boxMethod = modeClass.getDeclaredMethod("box-impl", String::class.java)
|
||||
return boxMethod.invoke(null, value) as OpenAIToolChoice
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,94 @@
|
|||
package ee.carlrobert.codegpt.agent.clients
|
||||
|
||||
import ai.koog.prompt.message.Message
|
||||
import ai.koog.prompt.streaming.StreamFrame
|
||||
import ai.koog.prompt.streaming.toMessageResponses
|
||||
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceChatCompletionSettingsState
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import kotlinx.coroutines.flow.asFlow
|
||||
import kotlinx.coroutines.flow.toList
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.junit.Test
|
||||
|
||||
class StreamingPayloadNormalizerTest {
|
||||
|
||||
@Test
|
||||
fun `should unwrap sse data payload`() {
|
||||
assertThat(normalizeSsePayload("data: {\"id\":\"chunk-1\"}"))
|
||||
.isEqualTo("{\"id\":\"chunk-1\"}")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `should ignore done sentinel`() {
|
||||
assertThat(normalizeSsePayload("data: [DONE]")).isNull()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `should unwrap multiline sse event frames`() {
|
||||
assertThat(
|
||||
normalizeSsePayload("event: response.created\ndata: {\"type\":\"response.created\"}\n")
|
||||
).isEqualTo("{\"type\":\"response.created\"}")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `should unwrap compact sse frames where event and data share a line`() {
|
||||
assertThat(
|
||||
normalizeSsePayload("event: response.created data: {\"type\":\"response.created\"}")
|
||||
).isEqualTo("{\"type\":\"response.created\"}")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `custom openai client should decode sse wrapped chat completion chunks`() {
|
||||
val state = CustomServiceChatCompletionSettingsState().apply {
|
||||
url = "https://example.com/v1/chat/completions"
|
||||
body["stream"] = true
|
||||
}
|
||||
val client = CustomOpenAILLMClient.fromSettingsState("test-key", state)
|
||||
val decodeMethod = client.javaClass.getDeclaredMethod("decodeStreamingResponse", String::class.java)
|
||||
decodeMethod.isAccessible = true
|
||||
|
||||
val response = decodeMethod.invoke(
|
||||
client,
|
||||
"data: {\"choices\":[],\"created\":0,\"id\":\"chunk-1\",\"model\":\"test-model\"}"
|
||||
) as CustomOpenAIChatCompletionStreamResponse
|
||||
|
||||
assertThat(response.id).isEqualTo("chunk-1")
|
||||
assertThat(response.model).isEqualTo("test-model")
|
||||
assertThat(response.choices).isEmpty()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `responses api streaming tool call should collapse into one named tool call`() {
|
||||
val state = CustomServiceChatCompletionSettingsState().apply {
|
||||
url = "https://example.com/v1/responses"
|
||||
body.clear()
|
||||
body["stream"] = true
|
||||
body["input"] = "\$OPENAI_MESSAGES"
|
||||
}
|
||||
val client = CustomOpenAILLMClient.fromSettingsState("test-key", state)
|
||||
val decodeMethod = client.javaClass.getDeclaredMethod("decodeStreamingResponse", String::class.java)
|
||||
decodeMethod.isAccessible = true
|
||||
val processMethod = client.javaClass.getDeclaredMethod("processStreamingResponse", Flow::class.java)
|
||||
processMethod.isAccessible = true
|
||||
|
||||
val chunks = listOf(
|
||||
"""{"type":"response.output_item.added","item":{"type":"function_call","id":"fc_1","call_id":"call_diag","name":"Diagnostics","arguments":""},"output_index":0,"sequence_number":1}""",
|
||||
"""{"type":"response.function_call_arguments.delta","item_id":"fc_1","output_index":0,"delta":"{\"file_path\":\"/tmp/mainwindow.cpp\",\"filter\":\"all\"}","call_id":"call_diag","sequence_number":2}""",
|
||||
"""{"type":"response.completed","response":{"id":"resp-tool","object":"response","created_at":1,"model":"test-model","output":[{"type":"function_call","id":"fc_1","call_id":"call_diag","name":"Diagnostics","arguments":"{\"file_path\":\"/tmp/mainwindow.cpp\",\"filter\":\"all\"}","status":"completed"}],"parallel_tool_calls":true,"status":"completed","text":{}},"sequence_number":3}"""
|
||||
).map { payload ->
|
||||
decodeMethod.invoke(client, payload) as CustomOpenAIChatCompletionStreamResponse
|
||||
}
|
||||
|
||||
val responses = runBlocking {
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
val frames = processMethod.invoke(client, chunks.asFlow()) as Flow<StreamFrame>
|
||||
frames.toList().toMessageResponses()
|
||||
}
|
||||
val toolCall = responses.filterIsInstance<Message.Tool.Call>().single()
|
||||
|
||||
assertThat(toolCall.id).isEqualTo("call_diag")
|
||||
assertThat(toolCall.tool).isEqualTo("Diagnostics")
|
||||
assertThat(toolCall.content).isEqualTo("""{"file_path":"/tmp/mainwindow.cpp","filter":"all"}""")
|
||||
}
|
||||
}
|
||||
|
|
@ -3,6 +3,7 @@ package ee.carlrobert.codegpt.agent.strategy
|
|||
import ai.koog.prompt.llm.LLMProvider
|
||||
import ai.koog.prompt.message.Message
|
||||
import ai.koog.prompt.message.ResponseMetaInfo
|
||||
import ee.carlrobert.codegpt.agent.normalizeToolArgumentsJson
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import kotlin.test.Test
|
||||
|
||||
|
|
@ -94,4 +95,11 @@ class SingleRunStrategyProviderTest {
|
|||
assertThat(toolCall.tool).isEqualTo("Read")
|
||||
assertThat(toolCall.content).isEqualTo("""{"path":"build.gradle.kts"}""")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `normalize tool arguments unwraps string encoded json object`() {
|
||||
val actual = normalizeToolArgumentsJson("\"{\\\"file_path\\\":\\\"/tmp/fixture.txt\\\"}\"")
|
||||
|
||||
assertThat(actual).isEqualTo("""{"file_path":"/tmp/fixture.txt"}""")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ import com.intellij.util.messages.MessageBusConnection
|
|||
import ee.carlrobert.codegpt.settings.service.FeatureType
|
||||
import ee.carlrobert.codegpt.settings.service.ModelChangeNotifier
|
||||
import ee.carlrobert.codegpt.settings.service.ServiceType
|
||||
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettingsState
|
||||
import ee.carlrobert.codegpt.settings.service.custom.CustomServicesSettings
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import testsupport.IntegrationTest
|
||||
import java.util.concurrent.atomic.AtomicReference
|
||||
|
|
@ -202,6 +204,30 @@ class ModelSettingsTest : IntegrationTest() {
|
|||
assertThat(result).isEqualTo(ServiceType.OPENAI)
|
||||
}
|
||||
|
||||
fun `test custom openai models keep configured order for chat and agent`() {
|
||||
val customServicesSettings = service<CustomServicesSettings>()
|
||||
val first = createCustomService(name = "First", model = "chat-first")
|
||||
val second = createCustomService(name = "Second", model = "chat-second")
|
||||
|
||||
customServicesSettings.state.services.clear()
|
||||
customServicesSettings.state.services.add(first)
|
||||
customServicesSettings.state.services.add(second)
|
||||
|
||||
assertThat(customModelNames(FeatureType.CHAT))
|
||||
.containsExactly("First (chat-first)", "Second (chat-second)")
|
||||
assertThat(customModelNames(FeatureType.AGENT))
|
||||
.containsExactly("First (chat-first)", "Second (chat-second)")
|
||||
|
||||
customServicesSettings.state.services.clear()
|
||||
customServicesSettings.state.services.add(second)
|
||||
customServicesSettings.state.services.add(first)
|
||||
|
||||
assertThat(customModelNames(FeatureType.CHAT))
|
||||
.containsExactly("Second (chat-second)", "First (chat-first)")
|
||||
assertThat(customModelNames(FeatureType.AGENT))
|
||||
.containsExactly("Second (chat-second)", "First (chat-first)")
|
||||
}
|
||||
|
||||
fun `test migrateMissingProviderInformation updates missing providers`() {
|
||||
val state = ModelSettingsState()
|
||||
val detailsState = ModelDetailsState()
|
||||
|
|
@ -227,4 +253,18 @@ class ModelSettingsTest : IntegrationTest() {
|
|||
assertThat(modelSettings.getStoredModelForFeature(FeatureType.CHAT)).isEqualTo("unknown-model")
|
||||
assertThat(modelSettings.getStoredProviderForFeature(FeatureType.CHAT)).isNull()
|
||||
}
|
||||
|
||||
private fun customModelNames(featureType: FeatureType): List<String> {
|
||||
return modelSettings.getAvailableModels(featureType)
|
||||
.filter { it.provider == ServiceType.CUSTOM_OPENAI }
|
||||
.map { it.displayName }
|
||||
}
|
||||
|
||||
private fun createCustomService(name: String, model: String): CustomServiceSettingsState {
|
||||
return CustomServiceSettingsState().apply {
|
||||
this.name = name
|
||||
chatCompletionSettings.body.clear()
|
||||
chatCompletionSettings.body["model"] = model
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,80 @@
|
|||
package ee.carlrobert.codegpt.ui.textarea
|
||||
|
||||
import com.intellij.openapi.components.service
|
||||
import ee.carlrobert.codegpt.diagnostics.DiagnosticsFilter
|
||||
import ee.carlrobert.codegpt.diagnostics.ProjectDiagnosticsService
|
||||
import ee.carlrobert.codegpt.settings.service.FeatureType
|
||||
import ee.carlrobert.codegpt.ui.textarea.header.tag.DiagnosticsTagDetails
|
||||
import ee.carlrobert.codegpt.ui.textarea.header.tag.EditorTagDetails
|
||||
import ee.carlrobert.codegpt.ui.textarea.header.tag.FileTagDetails
|
||||
import ee.carlrobert.codegpt.ui.textarea.header.tag.TagManager
|
||||
import ee.carlrobert.codegpt.ui.textarea.lookup.action.selectedContextFiles
|
||||
import ee.carlrobert.codegpt.ui.textarea.lookup.group.DiagnosticsGroupItem
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import testsupport.IntegrationTest
|
||||
|
||||
class DiagnosticsIntegrationTest : IntegrationTest() {
|
||||
|
||||
fun `test diagnostics service should return errors for broken file`() {
|
||||
val brokenFile = myFixture.configureByText(
|
||||
"Broken.java",
|
||||
"class Broken { void test() { int value = ; } }"
|
||||
).virtualFile
|
||||
myFixture.doHighlighting()
|
||||
|
||||
val report = project.service<ProjectDiagnosticsService>()
|
||||
.collect(brokenFile, DiagnosticsFilter.ERRORS_ONLY)
|
||||
|
||||
assertThat(report.hasDiagnostics).isTrue()
|
||||
assertThat(report.content).contains("[ERROR]")
|
||||
assertThat(report.content).contains("Broken.java")
|
||||
}
|
||||
|
||||
fun `test diagnostics service should return empty for clean file`() {
|
||||
val cleanFile = myFixture.configureByText(
|
||||
"Clean.java",
|
||||
"class Clean { void test() { int value = 1; } }"
|
||||
).virtualFile
|
||||
myFixture.doHighlighting()
|
||||
|
||||
val report = project.service<ProjectDiagnosticsService>()
|
||||
.collect(cleanFile, DiagnosticsFilter.ERRORS_ONLY)
|
||||
|
||||
assertThat(report.hasDiagnostics).isFalse()
|
||||
assertThat(report.content).isBlank()
|
||||
assertThat(report.error).isNull()
|
||||
}
|
||||
|
||||
fun `test selectedContextFiles should use selected file context only`() {
|
||||
val firstFile = myFixture.configureByText("First.java", "class First {}").virtualFile
|
||||
val secondFile = myFixture.configureByText("Second.java", "class Second {}").virtualFile
|
||||
val unselectedEditorTag = EditorTagDetails(secondFile).apply { selected = false }
|
||||
|
||||
val contextFiles = selectedContextFiles(
|
||||
listOf(
|
||||
FileTagDetails(firstFile),
|
||||
DiagnosticsTagDetails(firstFile, DiagnosticsFilter.ALL),
|
||||
EditorTagDetails(secondFile),
|
||||
unselectedEditorTag
|
||||
)
|
||||
)
|
||||
|
||||
assertThat(contextFiles.map { it.path })
|
||||
.containsExactly(firstFile.path, secondFile.path)
|
||||
}
|
||||
|
||||
fun `test diagnostics group should be available in agent mode with file context`() {
|
||||
val file = myFixture.configureByText("AgentFile.java", "class AgentFile {}").virtualFile
|
||||
val tagManager = TagManager().apply {
|
||||
addTag(FileTagDetails(file))
|
||||
}
|
||||
|
||||
val groups = SearchManager(project, tagManager, FeatureType.AGENT).getDefaultGroups()
|
||||
val diagnosticsGroup = groups.filterIsInstance<DiagnosticsGroupItem>().single()
|
||||
val actions = runBlocking { diagnosticsGroup.getLookupItems("") }
|
||||
|
||||
assertThat(diagnosticsGroup.enabled).isTrue()
|
||||
assertThat(actions.map { it.displayName }).containsExactly("Errors only", "All")
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue