feat(file_utils): robust path handling and safe directory listing (#1195)

Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: bytecii <bytecii@users.noreply.github.com>
This commit is contained in:
Phives 2026-02-22 04:41:18 -05:00 committed by GitHub
parent e76568c1e1
commit 6776a90a6f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 702 additions and 93 deletions

View file

@ -33,3 +33,7 @@ class NoPermissionException(Exception):
class ProgramException(Exception):
def __init__(self, text: str):
self.text = text
class PathEscapesBaseError(ValueError):
"""Raised when a path resolves outside its allowed base directory."""

View file

@ -15,7 +15,6 @@
import asyncio
import datetime
import logging
import os
import platform
from pathlib import Path
from typing import Any
@ -58,7 +57,7 @@ from app.service.task import (
set_current_task_id,
)
from app.utils.event_loop_utils import set_main_event_loop
from app.utils.file_utils import get_working_directory
from app.utils.file_utils import get_working_directory, list_files
from app.utils.server.sync_step import sync_step
from app.utils.telemetry.workforce_metrics import WorkforceMetricsCallback
from app.utils.workforce import Workforce
@ -92,41 +91,24 @@ def format_task_context(
# Skip file listing if requested
if not skip_files:
working_directory = task_data.get("working_directory")
skip_ext = (".pyc", ".tmp")
if working_directory:
try:
if os.path.exists(working_directory):
generated_files = []
for root, dirs, files in os.walk(working_directory):
dirs[:] = [
d
for d in dirs
if not d.startswith(".")
and d
not in ["node_modules", "__pycache__", "venv"]
]
for file in files:
if not file.startswith(".") and not file.endswith(
skip_ext
):
file_path = os.path.join(root, file)
absolute_path = os.path.abspath(file_path)
# Only add if not seen before
if (
seen_files is None
or absolute_path not in seen_files
):
generated_files.append(absolute_path)
if seen_files is not None:
seen_files.add(absolute_path)
if generated_files:
context_parts.append(
"Generated Files from Previous Task:"
)
for file_path in sorted(generated_files):
context_parts.append(f" - {file_path}")
generated_files = list_files(
working_directory,
base=working_directory,
skip_dirs={"node_modules", "__pycache__", "venv"},
skip_extensions=(".pyc", ".tmp"),
skip_prefix=".",
)
if seen_files is not None:
generated_files = [
p for p in generated_files if p not in seen_files
]
seen_files.update(generated_files)
if generated_files:
context_parts.append("Generated Files from Previous Task:")
for file_path in sorted(generated_files):
context_parts.append(f" - {file_path}")
except Exception as e:
logger.warning(f"Failed to collect generated files: {e}")
@ -172,31 +154,20 @@ def collect_previous_task_context(
f"Previous Task Result:\n{previous_task_result}\n"
)
# Collect generated files from working directory
# Collect generated files from working directory (safe listing)
try:
if os.path.exists(working_directory):
generated_files = []
for root, dirs, files in os.walk(working_directory):
dirs[:] = [
d
for d in dirs
if not d.startswith(".")
and d not in ["node_modules", "__pycache__", "venv"]
]
skip_ext = (".pyc", ".tmp")
for file in files:
if not file.startswith(".") and not file.endswith(
skip_ext
):
file_path = os.path.join(root, file)
absolute_path = os.path.abspath(file_path)
generated_files.append(absolute_path)
if generated_files:
context_parts.append("Generated Files from Previous Task:")
for file_path in sorted(generated_files):
context_parts.append(f" - {file_path}")
context_parts.append("")
generated_files = list_files(
working_directory,
base=working_directory,
skip_dirs={"node_modules", "__pycache__", "venv"},
skip_extensions=(".pyc", ".tmp"),
skip_prefix=".",
)
if generated_files:
context_parts.append("Generated Files from Previous Task:")
for file_path in sorted(generated_files):
context_parts.append(f" - {file_path}")
context_parts.append("")
except Exception as e:
logger.warning(f"Failed to collect generated files: {e}")
@ -272,30 +243,21 @@ def build_conversation_context(
context += f"Assistant: {entry['content']}\n\n"
if working_directories:
all_generated_files = set() # Use set to avoid duplicates
all_generated_files: set[str] = set()
for working_directory in working_directories:
try:
if os.path.exists(working_directory):
for root, dirs, files in os.walk(working_directory):
dirs[:] = [
d
for d in dirs
if not d.startswith(".")
and d
not in ["node_modules", "__pycache__", "venv"]
]
for file in files:
if not file.startswith(
"."
) and not file.endswith((".pyc", ".tmp")):
file_path = os.path.join(root, file)
absolute_path = os.path.abspath(file_path)
all_generated_files.add(absolute_path)
files_list = list_files(
working_directory,
base=working_directory,
skip_dirs={"node_modules", "__pycache__", "venv"},
skip_extensions=(".pyc", ".tmp"),
skip_prefix=".",
)
all_generated_files.update(files_list)
except Exception as e:
logger.warning(
"Failed to collect generated "
f"files from {working_directory}"
f": {e}"
f"files from {working_directory}: {e}"
)
if all_generated_files:

View file

@ -11,16 +11,254 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
"""File system utilities."""
"""File system utilities with robust path handling and edge-case safety."""
import logging
import os
import platform
import shutil
from pathlib import Path
from app.component.environment import env
from app.exception.exception import PathEscapesBaseError
from app.model.chat import Chat
logger = logging.getLogger(__name__)
logger = logging.getLogger("file_utils")
# Windows has a 260-character path limit unless long path support is enabled
MAX_PATH_LENGTH_WIN = 260
MAX_PATH_LENGTH_UNIX = 4096
# Default directory names to skip when listing (list_files)
DEFAULT_SKIP_DIRS = frozenset(
{".git", "node_modules", "__pycache__", "venv", ".venv"}
)
# Default file extensions to skip when listing (list_files)
DEFAULT_SKIP_EXTENSIONS: tuple[str, ...] = (".pyc", ".tmp", ".temp")
def _max_path_length() -> int:
"""Return the platform-appropriate max path length for validation."""
return (
MAX_PATH_LENGTH_WIN
if platform.system() == "Windows"
else MAX_PATH_LENGTH_UNIX
)
def _is_under_base(path_real: str, base_real: str) -> bool:
"""Return True if path_real is at or under base_real (both already realpath'd)."""
base = base_real.rstrip(os.sep)
return path_real.startswith(base + os.sep) or path_real == base
def _should_skip(
name: str,
skip_prefix: str,
skip_extensions: tuple[str, ...] = (),
) -> bool:
"""Return True if a file or directory name should be excluded from listing."""
if name.startswith(skip_prefix):
return True
return any(name.endswith(ext) for ext in skip_extensions)
def join_under_base(base: str, *parts: str) -> str | None:
"""Join path parts onto base, ensuring the result stays under base.
Args:
base (str): Base directory; must exist as a directory.
*parts (str): Path components to join onto base.
Returns:
str | None: Resolved absolute path if valid and under base, None otherwise.
"""
if not base or not base.strip():
return None
try:
base_resolved = Path(base).resolve()
if not base_resolved.is_dir():
return None
combined = base_resolved
for p in parts:
if p is None or (isinstance(p, str) and ".." in p.split(os.sep)):
return None
combined = combined / p
resolved = combined.resolve()
try:
resolved.relative_to(base_resolved)
except ValueError:
return None
if len(str(resolved)) > _max_path_length():
return None
return str(resolved)
except (OSError, RuntimeError) as e:
logger.debug("join_under_base failed: %s", e)
return None
def is_safe_path(path: str, base: str) -> bool:
"""Return True if path is under base (realpath) and within path length limits.
Args:
path (str): Path to validate (file or directory).
base (str): Base directory that path must be under.
Returns:
bool: True if path resolves under base and within path length limits.
"""
if not path or not base:
return False
try:
base_real = os.path.realpath(base)
path_real = os.path.realpath(path)
if not _is_under_base(path_real, base_real):
return False
return len(path_real) <= _max_path_length()
except (OSError, RuntimeError):
return False
def resolve_under_base(path: str, base: str) -> str:
"""Resolve path and verify it stays under base. Raises if it escapes.
Args:
path (str): Path to resolve (relative or absolute).
base (str): Base directory that path must be confined to.
Returns:
str: Resolved real path.
Raises:
ValueError: If path is empty or whitespace.
PathEscapesBaseError: If path resolves outside base or exceeds path length.
OSError: If path resolution fails due to filesystem error.
"""
if not path or not path.strip():
raise ValueError(f"Path must be non-empty, got: {path!r}")
base_abs = os.path.abspath(base)
if not os.path.isdir(base_abs):
raise PathEscapesBaseError(f"Base is not a directory: {base!r}")
resolved = os.path.normpath(os.path.join(base_abs, path))
resolved_real = os.path.realpath(resolved)
base_real = os.path.realpath(base_abs)
if not _is_under_base(resolved_real, base_real):
raise PathEscapesBaseError(
f"Path escapes base: path={path!r} base={base!r}"
)
if len(resolved_real) > _max_path_length():
raise PathEscapesBaseError(
f"Path exceeds max length ({len(resolved_real)}): {resolved_real!r}"
)
return resolved_real
def normalize_working_path(path: str | Path | None) -> str:
"""
Normalize and validate a working directory path using pathlib.
Requires a non-empty path; raises ValueError if path is None or empty.
For invalid or nonexistent paths, falls back to parent or user home.
Args:
path: Working directory path (str or Path). Must be specified.
Returns:
Absolute, resolved directory path as a string.
Raises:
ValueError: If path is None or empty/whitespace.
"""
if path is None or not str(path).strip():
raise ValueError("Working directory path must be specified.")
p = Path(path).expanduser().resolve()
try:
if len(str(p)) > _max_path_length():
logger.warning("Working path too long, using parent: %s", p)
p = p.parent
if not p.exists():
if p.parent.exists() and p.parent.is_dir():
return str(p.parent)
return str(Path.home())
if p.is_dir():
return str(p)
return str(p.parent)
except (OSError, RuntimeError) as e:
logger.warning("Invalid working path %r: %s", path, e)
return str(Path.home())
def list_files(
dir_path: str,
base: str | None = None,
*,
max_entries: int = 10_000,
skip_dirs: set[str] | None = None,
skip_extensions: tuple[str, ...] = DEFAULT_SKIP_EXTENSIONS,
skip_prefix: str = ".",
) -> list[str]:
"""List files under dir_path with optional base confinement and filters.
If base is set, only returns paths that resolve under base (no traversal).
Args:
dir_path (str): Directory to list; must resolve under base when base is set.
base (str | None): Confinement base (default: cwd). Paths outside this are excluded.
max_entries (int): Maximum number of file paths to return.
skip_dirs (set[str] | None): Directory names to skip (default: DEFAULT_SKIP_DIRS).
skip_extensions (tuple[str, ...]): File extensions to skip (default: DEFAULT_SKIP_EXTENSIONS).
skip_prefix (str): Skip dirs/files whose name starts with this prefix.
Returns:
List of real absolute file paths under dir_path (subject to filters and max_entries).
"""
if not dir_path or not dir_path.strip():
logger.warning("list_files: empty dir_path")
return []
resolve_base = base if base else os.getcwd()
try:
resolved_dir = resolve_under_base(dir_path, resolve_base)
except PathEscapesBaseError as e:
logger.warning("list_files: %s", e)
return []
except (ValueError, OSError) as e:
logger.warning("list_files: invalid dir_path %r: %s", dir_path, e)
return []
try:
if not os.path.isdir(resolved_dir):
return []
except OSError:
return []
base_real = os.path.realpath(resolve_base)
skip_dirs = set(DEFAULT_SKIP_DIRS) if skip_dirs is None else skip_dirs
result: list[str] = []
try:
for root, dirs, files in os.walk(resolved_dir, followlinks=False):
dirs[:] = [
d
for d in dirs
if d not in skip_dirs and not _should_skip(d, skip_prefix)
]
for name in files:
if _should_skip(name, skip_prefix, skip_extensions):
continue
try:
file_path = os.path.join(root, name)
real_path = os.path.realpath(file_path)
if not _is_under_base(real_path, base_real):
logger.debug(
"list_files: skipping %r (escapes base)", file_path
)
continue
result.append(real_path)
if len(result) >= max_entries:
logger.debug(
"list_files hit max_entries=%d", max_entries
)
return result
except OSError:
continue
except OSError as e:
logger.warning("list_files failed for %r: %s", dir_path, e)
return result
def get_working_directory(options: Chat, task_lock=None) -> str:
@ -28,20 +266,24 @@ def get_working_directory(options: Chat, task_lock=None) -> str:
Get the correct working directory for file operations.
First checks if there's an updated path from improve API call,
then falls back to environment variable or default path.
Result is normalized for safety (traversal, length, existence).
"""
if not task_lock:
from app.service.task import get_task_lock_if_exists
task_lock = get_task_lock_if_exists(options.project_id)
raw: Path | str
if (
task_lock
and hasattr(task_lock, "new_folder_path")
and task_lock.new_folder_path
):
return str(task_lock.new_folder_path)
raw = Path(task_lock.new_folder_path)
else:
return env("file_save_path", options.file_save_path())
raw = Path(env("file_save_path", options.file_save_path()))
return normalize_working_path(raw)
def sync_eigent_skills_to_project(working_directory: str) -> None:

View file

@ -25,6 +25,7 @@ from app.service.chat_service import (
collect_previous_task_context,
construct_workforce,
format_agent_description,
format_task_context,
install_mcp,
new_agent_model,
question_confirm,
@ -44,6 +45,36 @@ from app.service.task import (
)
@pytest.mark.unit
class TestFormatTaskContext:
"""Test cases for format_task_context function."""
def test_format_task_context_with_working_directory_and_files(
self, temp_dir
):
"""Test format_task_context lists generated files via list_files."""
(temp_dir / "output.txt").write_text("content")
task_data = {
"task_content": "Create file",
"task_result": "Done",
"working_directory": str(temp_dir),
}
result = format_task_context(task_data, skip_files=False)
assert "Previous Task: Create file" in result
assert "output.txt" in result
assert "Generated Files from Previous Task:" in result
def test_format_task_context_skip_files(self, temp_dir):
"""Test format_task_context with skip_files=True omits file listing."""
task_data = {
"task_content": "Task",
"task_result": "Result",
"working_directory": str(temp_dir),
}
result = format_task_context(task_data, skip_files=True)
assert "Generated Files from Previous Task:" not in result
@pytest.mark.unit
class TestCollectPreviousTaskContext:
"""Test cases for collect_previous_task_context function."""
@ -230,14 +261,14 @@ class TestCollectPreviousTaskContext:
assert "Previous Task:" not in result
assert "Previous Task Result:" not in result
@patch("app.service.chat_service.logger")
@patch("app.utils.file_utils.logger")
def test_collect_previous_task_context_file_system_error(
self, mock_logger, temp_dir
):
"""Test collect_previous_task_context handles file system errors gracefully."""
working_directory = str(temp_dir)
# Mock os.walk to raise an exception
# Mock os.walk to raise an exception (used inside list_files)
with patch("os.walk", side_effect=PermissionError("Access denied")):
result = collect_previous_task_context(
working_directory=working_directory,
@ -251,7 +282,7 @@ class TestCollectPreviousTaskContext:
assert "Test task" in result
assert "Generated Files from Previous Task:" not in result
# Should log warning
# Warning is logged by file_utils.list_files
mock_logger.warning.assert_called_once()
def test_collect_previous_task_context_relative_paths(self, temp_dir):
@ -971,7 +1002,7 @@ class TestChatServiceErrorCases:
working_directory = str(temp_dir)
with patch("os.walk", side_effect=OSError("Permission denied")):
with patch("app.service.chat_service.logger") as mock_logger:
with patch("app.utils.file_utils.logger") as mock_logger:
result = collect_previous_task_context(
working_directory=working_directory,
previous_task_content="Test task",
@ -988,7 +1019,7 @@ class TestChatServiceErrorCases:
# Should not include file listing
assert "Generated Files from Previous Task:" not in result
# Should log warning
# Warning is logged by file_utils.list_files
mock_logger.warning.assert_called_once()
def test_collect_previous_task_context_abspath_used(self, temp_dir):

View file

@ -0,0 +1,360 @@
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
import os
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from app.exception.exception import PathEscapesBaseError
from app.utils.file_utils import (
DEFAULT_SKIP_DIRS,
get_working_directory,
is_safe_path,
join_under_base,
list_files,
normalize_working_path,
resolve_under_base,
)
def test_normalize_working_path_none_raises():
with pytest.raises(ValueError, match="must be specified"):
normalize_working_path(None)
def test_normalize_working_path_empty_string_raises():
with pytest.raises(ValueError, match="must be specified"):
normalize_working_path("")
def test_normalize_working_path_whitespace_raises():
with pytest.raises(ValueError, match="must be specified"):
normalize_working_path(" ")
def test_normalize_working_path_valid_dir_returns_absolute(temp_dir):
result = normalize_working_path(str(temp_dir))
assert os.path.isabs(result)
assert os.path.isdir(result)
assert os.path.realpath(result) == os.path.realpath(str(temp_dir))
def test_normalize_working_path_accepts_path_object(temp_dir):
result = normalize_working_path(temp_dir)
assert os.path.isabs(result)
assert os.path.realpath(result) == os.path.realpath(str(temp_dir))
def test_normalize_working_path_file_returns_parent(temp_dir):
f = temp_dir / "some_file.txt"
f.write_text("x")
result = normalize_working_path(str(f))
assert os.path.realpath(result) == os.path.realpath(str(temp_dir))
def test_normalize_working_path_nonexistent_falls_back_to_parent(temp_dir):
missing = temp_dir / "does_not_exist"
result = normalize_working_path(str(missing))
assert os.path.realpath(result) == os.path.realpath(str(temp_dir))
def test_normalize_working_path_nonexistent_deep_falls_back_to_home(tmp_path):
missing = tmp_path / "a" / "b" / "c"
result = normalize_working_path(str(missing))
assert os.path.isdir(result)
def test_resolve_under_base_empty_raises_value_error(temp_dir):
with pytest.raises(ValueError):
resolve_under_base("", str(temp_dir))
def test_resolve_under_base_whitespace_raises_value_error(temp_dir):
with pytest.raises(ValueError):
resolve_under_base(" ", str(temp_dir))
def test_resolve_under_base_relative_returns_realpath(temp_dir):
sub = temp_dir / "sub"
sub.mkdir()
result = resolve_under_base("sub", str(temp_dir))
assert result == os.path.realpath(str(sub))
def test_resolve_under_base_absolute_under_base(temp_dir):
sub = temp_dir / "deep"
sub.mkdir()
result = resolve_under_base(str(sub), str(temp_dir))
assert result == os.path.realpath(str(sub))
def test_resolve_under_base_traversal_raises(temp_dir):
with pytest.raises(PathEscapesBaseError):
resolve_under_base("..", str(temp_dir))
def test_resolve_under_base_nested_traversal_raises(temp_dir):
other = temp_dir / "a" / "b"
other.mkdir(parents=True)
with pytest.raises(PathEscapesBaseError):
resolve_under_base("../..", str(other))
def test_resolve_under_base_nonexistent_base_raises(temp_dir):
with pytest.raises(PathEscapesBaseError, match="not a directory"):
resolve_under_base("sub", str(temp_dir / "no_such_dir"))
def test_resolve_under_base_absolute_outside_raises(temp_dir):
with pytest.raises(PathEscapesBaseError):
resolve_under_base("/etc/passwd", str(temp_dir))
def test_join_under_base_empty_base_returns_none():
assert join_under_base("", "a") is None
def test_join_under_base_whitespace_base_returns_none():
assert join_under_base(" ", "a") is None
def test_join_under_base_nonexistent_base_returns_none():
assert join_under_base("/nonexistent/path/xyz", "a") is None
def test_join_under_base_single_part(temp_dir):
result = join_under_base(str(temp_dir), "child")
assert result == str((Path(temp_dir) / "child").resolve())
def test_join_under_base_multiple_parts(temp_dir):
result = join_under_base(str(temp_dir), "a", "b")
assert result is not None
assert result == str((Path(temp_dir) / "a" / "b").resolve())
def test_join_under_base_traversal_part_returns_none(temp_dir):
assert join_under_base(str(temp_dir), "..", "etc") is None
def test_join_under_base_none_part_returns_none(temp_dir):
assert join_under_base(str(temp_dir), None) is None # type: ignore[arg-type]
def test_is_safe_path_empty_path_returns_false(temp_dir):
assert is_safe_path("", str(temp_dir)) is False
def test_is_safe_path_empty_base_returns_false(temp_dir):
assert is_safe_path(str(temp_dir), "") is False
def test_is_safe_path_subdir_returns_true(temp_dir):
sub = temp_dir / "sub"
sub.mkdir()
assert is_safe_path(str(sub), str(temp_dir)) is True
def test_is_safe_path_base_itself_returns_true(temp_dir):
assert is_safe_path(str(temp_dir), str(temp_dir)) is True
def test_is_safe_path_escapes_base_returns_false(temp_dir):
assert is_safe_path("/etc/passwd", str(temp_dir)) is False
def test_is_safe_path_sibling_dir_returns_false(tmp_path):
dir_a = tmp_path / "a"
dir_b = tmp_path / "b"
dir_a.mkdir()
dir_b.mkdir()
assert is_safe_path(str(dir_b), str(dir_a)) is False
def test_list_files_empty_path_returns_empty():
assert list_files("") == []
def test_list_files_whitespace_path_returns_empty():
assert list_files(" ") == []
def test_list_files_nonexistent_path_returns_empty():
assert list_files("/nonexistent/path/12345") == []
def test_list_files_lists_files_recursively(temp_dir):
(temp_dir / "a.txt").write_text("a")
(temp_dir / "b.txt").write_text("b")
sub = temp_dir / "sub"
sub.mkdir()
(sub / "c.txt").write_text("c")
result = list_files(str(temp_dir), base=str(temp_dir))
names = [os.path.basename(p) for p in result]
assert "a.txt" in names
assert "b.txt" in names
assert "c.txt" in names
def test_list_files_skips_default_dirs(temp_dir):
(temp_dir / "keep.txt").write_text("x")
(temp_dir / "node_modules").mkdir()
(temp_dir / "__pycache__").mkdir()
result = list_files(str(temp_dir), base=str(temp_dir))
names = [os.path.basename(p) for p in result]
assert "keep.txt" in names
assert "node_modules" not in names
assert "__pycache__" not in names
def test_list_files_skips_default_extensions(temp_dir):
(temp_dir / "good.txt").write_text("x")
(temp_dir / "bad.pyc").write_bytes(b"")
(temp_dir / "bad.tmp").write_text("")
result = list_files(str(temp_dir), base=str(temp_dir))
names = [os.path.basename(p) for p in result]
assert "good.txt" in names
assert "bad.pyc" not in names
assert "bad.tmp" not in names
def test_list_files_skips_dotfiles(temp_dir):
(temp_dir / "visible.txt").write_text("x")
(temp_dir / ".hidden").write_text("h")
result = list_files(str(temp_dir), base=str(temp_dir))
names = [os.path.basename(p) for p in result]
assert "visible.txt" in names
assert ".hidden" not in names
def test_list_files_respects_max_entries(temp_dir):
for i in range(10):
(temp_dir / f"file{i}.txt").write_text(str(i))
result = list_files(str(temp_dir), base=str(temp_dir), max_entries=3)
assert len(result) == 3
def test_list_files_dir_path_is_file_returns_empty(temp_dir):
f = temp_dir / "file.txt"
f.write_text("x")
result = list_files(str(f), base=str(temp_dir))
assert result == []
def test_list_files_escaping_base_returns_empty(temp_dir):
parent = str(temp_dir.parent)
result = list_files(parent, base=str(temp_dir))
assert result == []
def test_list_files_default_skip_dirs_constant():
assert ".git" in DEFAULT_SKIP_DIRS
assert "node_modules" in DEFAULT_SKIP_DIRS
assert "venv" in DEFAULT_SKIP_DIRS
assert "__pycache__" in DEFAULT_SKIP_DIRS
assert ".venv" in DEFAULT_SKIP_DIRS
def test_get_working_directory_uses_new_folder_path(temp_dir):
options = MagicMock()
options.file_save_path.return_value = "/default"
task_lock = MagicMock()
task_lock.new_folder_path = str(temp_dir)
result = get_working_directory(options, task_lock)
assert os.path.isdir(result)
assert os.path.realpath(result) == os.path.realpath(str(temp_dir))
def test_get_working_directory_falls_back_to_file_save_path():
options = MagicMock()
options.file_save_path.return_value = os.path.expanduser("~")
result = get_working_directory(options, task_lock=None)
assert os.path.isdir(result)
def test_get_working_directory_no_new_folder_path_uses_env(temp_dir):
options = MagicMock()
task_lock = MagicMock()
task_lock.new_folder_path = None
with patch("app.utils.file_utils.env", return_value=str(temp_dir)):
result = get_working_directory(options, task_lock)
assert os.path.realpath(result) == os.path.realpath(str(temp_dir))
def test_get_working_directory_task_lock_without_attribute(temp_dir):
options = MagicMock()
task_lock = MagicMock(spec=[]) # no new_folder_path attribute
with patch("app.utils.file_utils.env", return_value=str(temp_dir)):
result = get_working_directory(options, task_lock)
assert os.path.realpath(result) == os.path.realpath(str(temp_dir))
def test_normalize_working_path_tilde_expands():
result = normalize_working_path("~")
assert os.path.isabs(result)
assert os.path.isdir(result)
assert result == os.path.expanduser("~")
def test_resolve_under_base_file_as_base_raises(temp_dir):
f = temp_dir / "file.txt"
f.write_text("x")
with pytest.raises(PathEscapesBaseError, match="not a directory"):
resolve_under_base("sub", str(f))
def test_join_under_base_no_parts_returns_base(temp_dir):
result = join_under_base(str(temp_dir))
assert result == str(Path(temp_dir).resolve())
def test_is_safe_path_file_under_base_returns_true(temp_dir):
f = temp_dir / "file.txt"
f.write_text("x")
assert is_safe_path(str(f), str(temp_dir)) is True
def test_list_files_custom_skip_dirs(temp_dir):
(temp_dir / "keep.txt").write_text("x")
custom = temp_dir / "custom_skip"
custom.mkdir()
(custom / "inside.txt").write_text("y")
result = list_files(
str(temp_dir), base=str(temp_dir), skip_dirs={"custom_skip"}
)
names = [os.path.basename(p) for p in result]
assert "keep.txt" in names
assert "inside.txt" not in names
def test_list_files_custom_skip_extensions(temp_dir):
(temp_dir / "keep.txt").write_text("x")
(temp_dir / "skip.log").write_text("y")
result = list_files(
str(temp_dir), base=str(temp_dir), skip_extensions=(".log",)
)
names = [os.path.basename(p) for p in result]
assert "keep.txt" in names
assert "skip.log" not in names
def test_list_files_custom_skip_prefix(temp_dir):
(temp_dir / "keep.txt").write_text("x")
(temp_dir / "_private.txt").write_text("y")
result = list_files(str(temp_dir), base=str(temp_dir), skip_prefix="_")
names = [os.path.basename(p) for p in result]
assert "keep.txt" in names
assert "_private.txt" not in names