mirror of
https://github.com/carlrobertoh/ProxyAI.git
synced 2026-05-19 16:28:46 +00:00
fix: improve WriteTool path resolution and error handling
This commit is contained in:
parent
52b34eeb69
commit
55010e53f6
4 changed files with 165 additions and 44 deletions
|
|
@ -2,6 +2,8 @@ package ee.carlrobert.codegpt.agent.tools
|
|||
|
||||
import ai.koog.agents.core.tools.Tool
|
||||
import com.fasterxml.jackson.databind.ObjectMapper
|
||||
import com.fasterxml.jackson.datatype.jdk8.Jdk8Module
|
||||
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule
|
||||
import com.fasterxml.jackson.module.kotlin.registerKotlinModule
|
||||
import com.intellij.openapi.diagnostic.thisLogger
|
||||
import ee.carlrobert.codegpt.agent.ToolRunContext
|
||||
|
|
@ -171,7 +173,10 @@ abstract class BaseTool<Args : Any, Result : Any>(
|
|||
val updatedMap = updatedInput as? Map<*, *> ?: return null
|
||||
|
||||
return try {
|
||||
val mapper = ObjectMapper().registerKotlinModule()
|
||||
val mapper = ObjectMapper()
|
||||
.registerKotlinModule()
|
||||
.registerModule(Jdk8Module())
|
||||
.registerModule(JavaTimeModule())
|
||||
|
||||
val argsMap: MutableMap<String, Any?> = mapper.convertValue(
|
||||
currentArgs,
|
||||
|
|
@ -195,4 +200,4 @@ abstract class BaseTool<Args : Any, Result : Any>(
|
|||
null
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ import kotlinx.serialization.SerialName
|
|||
import kotlinx.serialization.Serializable
|
||||
import java.io.File
|
||||
import java.nio.charset.StandardCharsets
|
||||
import java.nio.file.Paths
|
||||
|
||||
/**
|
||||
* Writes content to files, creating new files or overwriting existing ones.
|
||||
|
|
@ -86,13 +87,6 @@ class WriteTool(
|
|||
}
|
||||
|
||||
override suspend fun doExecute(args: Args): Result {
|
||||
val svc = project.service<ProxyAISettingsService>()
|
||||
if (svc.isPathIgnored(args.filePath)) {
|
||||
return Result.Error(
|
||||
filePath = args.filePath,
|
||||
error = ".proxyai ignore rules block writing to this path"
|
||||
)
|
||||
}
|
||||
if (args.content.isBlank()) {
|
||||
return Result.Error(
|
||||
filePath = args.filePath,
|
||||
|
|
@ -101,23 +95,38 @@ class WriteTool(
|
|||
}
|
||||
|
||||
return try {
|
||||
val file = File(args.filePath)
|
||||
val file = resolveWriteTarget(args.filePath)
|
||||
val filePath = file.absolutePath
|
||||
val svc = project.service<ProxyAISettingsService>()
|
||||
if (svc.isPathIgnored(args.filePath) || svc.isPathIgnored(filePath)) {
|
||||
return Result.Error(
|
||||
filePath = filePath,
|
||||
error = ".proxyai ignore rules block writing to this path"
|
||||
)
|
||||
}
|
||||
|
||||
file.parentFile?.mkdirs()
|
||||
val parent = file.parentFile
|
||||
if (parent != null && !parent.exists() && !parent.mkdirs() && !parent.exists()) {
|
||||
return Result.Error(
|
||||
filePath = filePath,
|
||||
error = "Failed to create parent directories: ${parent.absolutePath}"
|
||||
)
|
||||
}
|
||||
|
||||
val isNewFile = !file.exists()
|
||||
|
||||
if (file.exists() && !file.canWrite()) {
|
||||
return Result.Error(
|
||||
filePath = args.filePath,
|
||||
error = "File is not writable: ${args.filePath}"
|
||||
filePath = filePath,
|
||||
error = "File is not writable: $filePath"
|
||||
)
|
||||
}
|
||||
|
||||
val fileUrl = file.toURI().toString()
|
||||
val virtualFile =
|
||||
VirtualFileManager.getInstance().findFileByUrl("file://${args.filePath}")
|
||||
VirtualFileManager.getInstance().findFileByUrl(fileUrl)
|
||||
?: if (isNewFile) {
|
||||
VirtualFileManager.getInstance().findFileByUrl("file://${args.filePath}")
|
||||
VirtualFileManager.getInstance().refreshAndFindFileByUrl(fileUrl)
|
||||
} else {
|
||||
null
|
||||
}
|
||||
|
|
@ -156,7 +165,7 @@ class WriteTool(
|
|||
val action = if (isNewFile) "created" else "overwritten"
|
||||
|
||||
Result.Success(
|
||||
filePath = args.filePath,
|
||||
filePath = filePath,
|
||||
bytesWritten = bytesWritten,
|
||||
isNewFile = isNewFile,
|
||||
message = "File $action successfully. $bytesWritten bytes written."
|
||||
|
|
@ -170,6 +179,26 @@ class WriteTool(
|
|||
}
|
||||
}
|
||||
|
||||
private fun resolveWriteTarget(requestedPath: String): File {
|
||||
val normalized = requestedPath.replace("\\", "/")
|
||||
val projectBase = project.basePath ?: return File(normalized)
|
||||
val projectPath = Paths.get(projectBase)
|
||||
|
||||
if (!File(normalized).isAbsolute) {
|
||||
return projectPath.resolve(normalized).normalize().toFile()
|
||||
}
|
||||
|
||||
val relativePath = normalized.removePrefix("/")
|
||||
val asProjectRelative = projectPath.resolve(relativePath).normalize().toFile()
|
||||
val asAbsolute = File(normalized)
|
||||
|
||||
if (!asAbsolute.exists() && asProjectRelative.parentFile?.exists() == true) {
|
||||
return asProjectRelative
|
||||
}
|
||||
|
||||
return asAbsolute
|
||||
}
|
||||
|
||||
override fun createDeniedResult(
|
||||
originalArgs: Args,
|
||||
deniedReason: String
|
||||
|
|
|
|||
|
|
@ -2,7 +2,11 @@ package ee.carlrobert.codegpt.settings.hooks
|
|||
|
||||
import com.fasterxml.jackson.databind.JsonNode
|
||||
import com.fasterxml.jackson.databind.ObjectMapper
|
||||
import com.fasterxml.jackson.datatype.jdk8.Jdk8Module
|
||||
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule
|
||||
import com.fasterxml.jackson.module.kotlin.registerKotlinModule
|
||||
import com.intellij.execution.configurations.GeneralCommandLine
|
||||
import com.intellij.openapi.util.text.StringUtil
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.withContext
|
||||
import org.slf4j.LoggerFactory
|
||||
|
|
@ -13,7 +17,10 @@ import java.util.concurrent.TimeUnit
|
|||
import java.util.concurrent.TimeoutException
|
||||
|
||||
class HookExecutionService {
|
||||
private val objectMapper = ObjectMapper().registerKotlinModule()
|
||||
private val objectMapper = ObjectMapper()
|
||||
.registerKotlinModule()
|
||||
.registerModule(Jdk8Module())
|
||||
.registerModule(JavaTimeModule())
|
||||
private val logger = LoggerFactory.getLogger(HookExecutionService::class.java)
|
||||
|
||||
suspend fun executeHook(
|
||||
|
|
@ -30,17 +37,8 @@ class HookExecutionService {
|
|||
}
|
||||
|
||||
try {
|
||||
val command = shellCommand(hookConfig.command)
|
||||
val processBuilder = ProcessBuilder(command)
|
||||
|
||||
environment.forEach { (key, value) ->
|
||||
processBuilder.environment()[key] = value
|
||||
}
|
||||
|
||||
val process = processBuilder
|
||||
.directory(File(projectRoot))
|
||||
.redirectErrorStream(true)
|
||||
.start()
|
||||
val process =
|
||||
buildCommandLine(hookConfig.command, environment, projectRoot).createProcess()
|
||||
|
||||
process.outputStream.bufferedWriter().use { writer ->
|
||||
try {
|
||||
|
|
@ -60,36 +58,44 @@ class HookExecutionService {
|
|||
return@withContext HookExecutionResult.Timeout
|
||||
}
|
||||
|
||||
val output = BufferedReader(InputStreamReader(process.inputStream))
|
||||
val stdout = BufferedReader(InputStreamReader(process.inputStream))
|
||||
.use { it.readText() }
|
||||
val stderr = BufferedReader(InputStreamReader(process.errorStream))
|
||||
.use { it.readText() }
|
||||
|
||||
logger.debug("Hook '${hookConfig.command}' output: $output")
|
||||
if (stdout.isNotBlank()) {
|
||||
logger.debug("Hook '${hookConfig.command}' stdout: ${truncateForLog(stdout)}")
|
||||
}
|
||||
if (stderr.isNotBlank()) {
|
||||
logger.debug("Hook '${hookConfig.command}' stderr: ${truncateForLog(stderr)}")
|
||||
}
|
||||
|
||||
when (val exitCode = process.exitValue()) {
|
||||
0 -> {
|
||||
val response = try {
|
||||
val tree = objectMapper.readTree(output)
|
||||
val tree = parseJsonNodeOrNull(stdout)
|
||||
if (tree != null) {
|
||||
HookExecutionResult.Success(responseAsMap(tree))
|
||||
} catch (e: Exception) {
|
||||
logger.warn("Hook '${hookConfig.command}' did not return valid JSON, treating as empty output", e)
|
||||
} else {
|
||||
HookExecutionResult.Success(emptyMap())
|
||||
}
|
||||
response
|
||||
}
|
||||
|
||||
2 -> {
|
||||
val reason = try {
|
||||
val tree = objectMapper.readTree(output)
|
||||
tree.path("reason").asText()
|
||||
} catch (e: Exception) {
|
||||
logger.warn("Hook '${hookConfig.command}' did not return valid JSON, reason: $e")
|
||||
"Hook denied execution"
|
||||
}
|
||||
val reason = parseJsonNodeOrNull(stdout)?.path("reason")?.asText()
|
||||
?.takeIf { it.isNotBlank() }
|
||||
?: stdout.trim().takeIf { it.isNotBlank() }
|
||||
?: "Hook denied execution"
|
||||
logger.info("Hook '${hookConfig.command}' denied operation: $reason")
|
||||
HookExecutionResult.Denied(reason)
|
||||
}
|
||||
|
||||
else -> {
|
||||
logger.error("Hook '${hookConfig.command}' failed with exit code $exitCode: $output")
|
||||
HookExecutionResult.Failure(output)
|
||||
val error = listOf(stdout.trim(), stderr.trim())
|
||||
.filter { it.isNotBlank() }
|
||||
.joinToString("\n")
|
||||
.ifBlank { "Hook failed with exit code $exitCode" }
|
||||
logger.error("Hook '${hookConfig.command}' failed with exit code $exitCode: $error")
|
||||
HookExecutionResult.Failure(error)
|
||||
}
|
||||
}
|
||||
} catch (e: TimeoutException) {
|
||||
|
|
@ -105,10 +111,26 @@ class HookExecutionService {
|
|||
return when {
|
||||
System.getProperty("os.name").startsWith("Windows") ->
|
||||
listOf("cmd.exe", "/c", command)
|
||||
|
||||
else -> listOf("sh", "-c", command)
|
||||
}
|
||||
}
|
||||
|
||||
private fun buildCommandLine(
|
||||
command: String,
|
||||
environment: Map<String, String>,
|
||||
projectRoot: String
|
||||
): GeneralCommandLine {
|
||||
val shellCommand = shellCommand(command)
|
||||
return GeneralCommandLine().apply {
|
||||
exePath = shellCommand.first()
|
||||
addParameters(shellCommand.drop(1))
|
||||
withWorkDirectory(File(projectRoot))
|
||||
withEnvironment(environment)
|
||||
isRedirectErrorStream = false
|
||||
}
|
||||
}
|
||||
|
||||
private fun buildEnvironment(event: HookEventType, projectRoot: String): Map<String, String> {
|
||||
return mapOf(
|
||||
"PROXYAI_PROJECT_DIR" to projectRoot,
|
||||
|
|
@ -116,6 +138,21 @@ class HookExecutionService {
|
|||
)
|
||||
}
|
||||
|
||||
private fun parseJsonNodeOrNull(raw: String): JsonNode? {
|
||||
if (raw.isBlank()) return null
|
||||
val normalized = raw.dropWhile { ch ->
|
||||
ch.code < 32 && ch != '\n' && ch != '\r' && ch != '\t'
|
||||
}.trimStart()
|
||||
val first = normalized.firstOrNull() ?: return null
|
||||
if (first != '{' && first != '[') return null
|
||||
return runCatching { objectMapper.readTree(normalized) }.getOrNull()
|
||||
}
|
||||
|
||||
private fun truncateForLog(value: String, maxLen: Int = 500): String {
|
||||
val trimmed = value.trim()
|
||||
return StringUtil.shortenTextWithEllipsis(trimmed, maxLen, 0)
|
||||
}
|
||||
|
||||
private fun responseAsMap(node: JsonNode): Map<String, Any> {
|
||||
val result = mutableMapOf<String, Any>()
|
||||
node.fields().forEach { (key, value) ->
|
||||
|
|
@ -137,6 +174,7 @@ class HookExecutionService {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
else -> value.toString()
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,49 @@
|
|||
package ee.carlrobert.codegpt.agent
|
||||
|
||||
import ee.carlrobert.codegpt.agent.tools.WriteTool
|
||||
import ee.carlrobert.codegpt.settings.hooks.HookManager
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import testsupport.IntegrationTest
|
||||
import java.io.File
|
||||
|
||||
class WriteToolPathResolutionTest : IntegrationTest() {
|
||||
|
||||
fun testWriteCreatesParentDirectoriesForNewFile() {
|
||||
val target = File(project.basePath, "app/components/NewLanding.tsx")
|
||||
|
||||
val result = runBlocking {
|
||||
WriteTool(project, HookManager(project))
|
||||
.execute(
|
||||
WriteTool.Args(
|
||||
target.absolutePath,
|
||||
"export default function NewLanding() {}"
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
assertThat(result).isInstanceOf(WriteTool.Result.Success::class.java)
|
||||
assertThat(target.exists()).isTrue()
|
||||
}
|
||||
|
||||
fun testWriteResolvesLikelyProjectRelativeAbsolutePath() {
|
||||
val pseudoRootDir = "__proxyai_write_test__"
|
||||
val target = File(project.basePath, "$pseudoRootDir/components/NewLanding.tsx")
|
||||
target.parentFile.mkdirs()
|
||||
|
||||
val result = runBlocking {
|
||||
WriteTool(project, HookManager(project))
|
||||
.execute(
|
||||
WriteTool.Args(
|
||||
"/$pseudoRootDir/components/NewLanding.tsx",
|
||||
"export default function NewLanding() {}"
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
assertThat(result).isInstanceOf(WriteTool.Result.Success::class.java)
|
||||
val success = result as WriteTool.Result.Success
|
||||
assertThat(success.filePath).isEqualTo(target.absolutePath)
|
||||
assertThat(target.readText()).contains("NewLanding")
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue