Refactor/ingestion (#209)

Co-authored-by: Filip Christiansen <22807962+filipchristiansen@users.noreply.github.com>
This commit is contained in:
Romain Courtois 2025-03-04 01:11:54 +01:00 committed by GitHub
parent c96a7d3d48
commit d6cb920660
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
34 changed files with 1199 additions and 1487 deletions

3
.gitignore vendored
View file

@ -173,3 +173,6 @@ Caddyfile
# ignore default output directory
tmp/*
# Gitingest
digest.txt

View file

@ -95,6 +95,7 @@ repos:
files: ^src/
additional_dependencies:
[
chardet,
click,
fastapi-analytics,
pytest-asyncio,
@ -112,6 +113,7 @@ repos:
- --rcfile=tests/.pylintrc
additional_dependencies:
[
chardet,
click,
fastapi-analytics,
pytest,

View file

@ -142,7 +142,7 @@ Gitingest aims to be friendly for first time contributors, with a simple Python
- [tiktoken](https://github.com/openai/tiktoken) - Token estimation
- [posthog](https://github.com/PostHog/posthog) - Amazing analytics
### Looking for a JavaScript/Node package?
### Looking for a JavaScript/Node package?
Check out the NPM alternative 📦 Repomix: <https://github.com/yamadashy/repomix>

View file

@ -1,12 +1,13 @@
[project]
name = "gitingest"
version = "0.1.3"
version = "0.1.4"
description="CLI tool to analyze and create text dumps of codebases for LLMs"
readme = {file = "README.md", content-type = "text/markdown" }
requires-python = ">= 3.8"
dependencies = [
"click>=8.0.0",
"tiktoken",
"tomli",
"typing_extensions; python_version < '3.10'",
]
@ -52,6 +53,7 @@ disable = [
"too-few-public-methods",
"broad-exception-caught",
"duplicate-code",
"fixme",
]
[tool.pycln]

View file

@ -1,3 +1,4 @@
chardet
click>=8.0.0
fastapi[standard]
python-dotenv

View file

@ -1,8 +1,8 @@
""" Gitingest: A package for ingesting data from Git repositories. """
from gitingest.query_ingestion import run_ingest_query
from gitingest.query_parser import parse_query
from gitingest.repository_clone import clone_repo
from gitingest.cloning import clone_repo
from gitingest.ingestion import ingest_query
from gitingest.query_parsing import parse_query
from gitingest.repository_ingest import ingest, ingest_async
__all__ = ["run_ingest_query", "clone_repo", "parse_query", "ingest", "ingest_async"]
__all__ = ["ingest_query", "clone_repo", "parse_query", "ingest", "ingest_async"]

View file

@ -7,7 +7,7 @@ from typing import Optional, Tuple
import click
from gitingest.config import MAX_FILE_SIZE, OUTPUT_FILE_PATH
from gitingest.config import MAX_FILE_SIZE, OUTPUT_FILE_NAME
from gitingest.repository_ingest import ingest_async
@ -92,15 +92,15 @@ async def _async_main(
include_patterns = set(include_pattern)
if not output:
output = OUTPUT_FILE_PATH
output = OUTPUT_FILE_NAME
summary, _, _ = await ingest_async(source, max_size, include_patterns, exclude_patterns, branch, output=output)
click.echo(f"Analysis complete! Output written to: {output}")
click.echo("\nSummary:")
click.echo(summary)
except Exception as e:
click.echo(f"Error: {e}", err=True)
except Exception as exc:
click.echo(f"Error: {exc}", err=True)
raise click.Abort()

View file

@ -6,7 +6,7 @@ from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
from gitingest.utils import async_timeout
from gitingest.utils.timeout_wrapper import async_timeout
TIMEOUT: int = 60
@ -38,6 +38,7 @@ class CloneConfig:
commit: Optional[str] = None
branch: Optional[str] = None
subpath: str = "/"
blob: bool = False
@async_timeout(TIMEOUT)
@ -72,14 +73,15 @@ async def clone_repo(config: CloneConfig) -> None:
parent_dir = Path(local_path).parent
try:
os.makedirs(parent_dir, exist_ok=True)
except OSError as e:
raise OSError(f"Failed to create parent directory {parent_dir}: {e}") from e
except OSError as exc:
raise OSError(f"Failed to create parent directory {parent_dir}: {exc}") from exc
# Check if the repository exists
if not await _check_repo_exists(url):
raise ValueError("Repository not found, make sure it is public")
clone_cmd = ["git", "clone", "--recurse-submodules", "--single-branch"]
clone_cmd = ["git", "clone", "--single-branch"]
# TODO re-enable --recurse-submodules
if partial_clone:
clone_cmd += ["--filter=blob:none", "--sparse"]
@ -98,7 +100,10 @@ async def clone_repo(config: CloneConfig) -> None:
checkout_cmd = ["git", "-C", local_path]
if partial_clone:
checkout_cmd += ["sparse-checkout", "set", config.subpath.lstrip("/")]
if config.blob:
checkout_cmd += ["sparse-checkout", "set", config.subpath.lstrip("/")[:-1]]
else:
checkout_cmd += ["sparse-checkout", "set", config.subpath.lstrip("/")]
if commit:
checkout_cmd += ["checkout", commit]
@ -149,7 +154,6 @@ async def _check_repo_exists(url: str) -> bool:
raise RuntimeError(f"Unexpected status code: {status_code}")
@async_timeout(TIMEOUT)
async def fetch_remote_branch_list(url: str) -> List[str]:
"""
Fetch the list of branches from a remote Git repository.

View file

@ -8,6 +8,6 @@ MAX_DIRECTORY_DEPTH = 20 # Maximum depth of directory traversal
MAX_FILES = 10_000 # Maximum number of files to process
MAX_TOTAL_SIZE_BYTES = 500 * 1024 * 1024 # 500 MB
OUTPUT_FILE_PATH = "digest.txt"
OUTPUT_FILE_NAME = "digest.txt"
TMP_BASE_PATH = Path(tempfile.gettempdir()) / "gitingest"

View file

@ -0,0 +1,143 @@
""" Define the schema for the filesystem representation. """
from __future__ import annotations
import os
from dataclasses import dataclass, field
from enum import Enum, auto
from pathlib import Path
from gitingest.exceptions import InvalidNotebookError
from gitingest.utils.ingestion_utils import _get_encoding_list
from gitingest.utils.notebook_utils import process_notebook
from gitingest.utils.textfile_checker_utils import is_textfile
SEPARATOR = "=" * 48 + "\n"
class FileSystemNodeType(Enum):
    """Kinds of nodes that can appear in the ingested file-system tree."""

    DIRECTORY = auto()
    FILE = auto()
@dataclass
class FileSystemStats:
    """Mutable counters accumulated while traversing a file-system tree.

    Holds the set of already-visited paths plus running totals of file count
    and byte size, so that traversal limits can be enforced.
    """

    # Paths already processed during this traversal (used to avoid revisits).
    visited: set[Path] = field(default_factory=set)
    # Running number of files processed so far.
    total_files: int = 0
    # Running total size, in bytes, of all processed files.
    total_size: int = 0
@dataclass
class FileSystemNode:  # pylint: disable=too-many-instance-attributes
    """
    Class representing a node in the file system (either a file or directory).

    This class has more than the recommended number of attributes because it needs to
    track various properties of files and directories for comprehensive analysis.
    """

    name: str
    type: FileSystemNodeType  # e.g., "directory" or "file"
    path_str: str  # display path, relative to the ingestion root
    path: Path  # actual on-disk path used for reading
    size: int = 0  # bytes; for directories this is accumulated by the caller
    file_count: int = 0  # files in this subtree (accumulated by the caller)
    dir_count: int = 0  # directories in this subtree (accumulated by the caller)
    depth: int = 0  # depth relative to the traversal root
    children: list[FileSystemNode] = field(default_factory=list)  # Using default_factory instead of empty list

    def sort_children(self) -> None:
        """
        Sort the children nodes of a directory according to a specific order.

        Order of sorting:
        1. README.md first
        2. Regular files (not starting with dot)
        3. Hidden files (starting with dot)
        4. Regular directories (not starting with dot)
        5. Hidden directories (starting with dot)

        All groups are sorted alphanumerically within themselves.
        """
        # Separate files and directories
        files = [child for child in self.children if child.type == FileSystemNodeType.FILE]
        directories = [child for child in self.children if child.type == FileSystemNodeType.DIRECTORY]
        # Find README.md (case-insensitive match on the name)
        readme_files = [f for f in files if f.name.lower() == "readme.md"]
        other_files = [f for f in files if f.name.lower() != "readme.md"]
        # Separate hidden and regular files/directories
        regular_files = [f for f in other_files if not f.name.startswith(".")]
        hidden_files = [f for f in other_files if f.name.startswith(".")]
        regular_dirs = [d for d in directories if not d.name.startswith(".")]
        hidden_dirs = [d for d in directories if d.name.startswith(".")]
        # Sort each group alphanumerically (README entries keep original order)
        regular_files.sort(key=lambda x: x.name)
        hidden_files.sort(key=lambda x: x.name)
        regular_dirs.sort(key=lambda x: x.name)
        hidden_dirs.sort(key=lambda x: x.name)
        self.children = readme_files + regular_files + hidden_files + regular_dirs + hidden_dirs

    @property
    def content_string(self) -> str:
        """
        Return the content of the node as a string.

        This property returns the content of the node as a string, including the path and content.

        Returns
        -------
        str
            A string representation of the node's content.
        """
        content_repr = SEPARATOR
        # Use forward slashes in output paths
        content_repr += f"File: {str(self.path_str).replace(os.sep, '/')}\n"
        content_repr += SEPARATOR
        content_repr += f"{self.content}\n\n"
        return content_repr

    @property
    def content(self) -> str:  # pylint: disable=too-many-return-statements
        """
        Read the content of a file.

        This function attempts to open a file and read its contents using UTF-8 encoding.
        If an error occurs during reading (e.g., file is not found or permission error),
        it returns an error message.

        Returns
        -------
        str
            The content of the file, or an error message if the file could not be read.
        """
        # Non-text files get a placeholder instead of raw bytes.
        if self.type == FileSystemNodeType.FILE and not is_textfile(self.path):
            return "[Non-text file]"
        try:
            if self.path.suffix == ".ipynb":
                try:
                    # Notebooks are converted to a plain-text rendering.
                    return process_notebook(self.path)
                except Exception as exc:
                    return f"Error processing notebook: {exc}"
            # Try each candidate encoding in priority order until one decodes.
            for encoding in _get_encoding_list():
                try:
                    with self.path.open(encoding=encoding) as f:
                        return f.read()
                except UnicodeDecodeError:
                    continue
                except OSError as exc:
                    return f"Error reading file: {exc}"
            return "Error: Unable to decode file with available encodings"
        except (OSError, InvalidNotebookError) as exc:
            return f"Error reading file: {exc}"

312
src/gitingest/ingestion.py Normal file
View file

@ -0,0 +1,312 @@
""" Functions to ingest and analyze a codebase directory or single file. """
import warnings
from pathlib import Path
from typing import Tuple
from gitingest.config import MAX_DIRECTORY_DEPTH, MAX_FILES, MAX_TOTAL_SIZE_BYTES
from gitingest.filesystem_schema import FileSystemNode, FileSystemNodeType, FileSystemStats
from gitingest.output_formatters import format_directory, format_single_file
from gitingest.query_parsing import ParsedQuery
from gitingest.utils.ingestion_utils import _should_exclude, _should_include
from gitingest.utils.path_utils import _is_safe_symlink
try:
import tomllib
except ImportError:
import tomli as tomllib
def ingest_query(query: ParsedQuery) -> Tuple[str, str, str]:
    """
    Run the ingestion process for a parsed query.

    This is the main entry point for analyzing a codebase directory or single file. It processes the query
    parameters, reads the file or directory content, and generates a summary, directory structure, and file content,
    along with token estimations.

    Parameters
    ----------
    query : ParsedQuery
        The parsed query object containing information about the repository and query parameters.

    Returns
    -------
    Tuple[str, str, str]
        A tuple containing the summary, directory structure, and file contents.

    Raises
    ------
    ValueError
        If the specified path cannot be found or if the file is not a text file.
    """
    subpath = Path(query.subpath.strip("/")).as_posix()
    path = query.local_path / subpath

    # Merge ignore patterns from a repo-local .gitingest file into the query.
    # NOTE(review): this runs before the existence check below; apply_gitingest_file
    # is a no-op when the file is missing, so the order is harmless — but confirm
    # it is intentional.
    apply_gitingest_file(path, query)

    if not path.exists():
        raise ValueError(f"{query.slug} cannot be found")

    # Single-file ingestion: either the URL pointed at a blob or the local
    # target itself is a file.
    if (query.type and query.type == "blob") or query.local_path.is_file():
        # TODO: We do this wrong! We should still check the branch and commit!
        if not path.is_file():
            raise ValueError(f"Path {path} is not a file")

        relative_path = path.relative_to(query.local_path)

        file_node = FileSystemNode(
            name=path.name,
            type=FileSystemNodeType.FILE,
            size=path.stat().st_size,
            file_count=1,
            path_str=str(relative_path),
            path=path,
        )

        return format_single_file(file_node, query)

    # Directory ingestion: build the tree rooted at `path` and format it.
    root_node = FileSystemNode(
        name=path.name,
        type=FileSystemNodeType.DIRECTORY,
        path_str=str(path.relative_to(query.local_path)),
        path=path,
    )

    stats = FileSystemStats()

    _process_node(
        node=root_node,
        query=query,
        stats=stats,
    )

    return format_directory(root_node, query)
def apply_gitingest_file(path: Path, query: ParsedQuery) -> None:
"""
Apply the .gitingest file to the query object.
This function reads the .gitingest file in the specified path and updates the query object with the ignore
patterns found in the file.
Parameters
----------
path : Path
The path of the directory to ingest.
query : ParsedQuery
The parsed query object containing information about the repository and query parameters.
It should have an attribute `ignore_patterns` which is either None or a set of strings.
"""
path_gitingest = path / ".gitingest"
if not path_gitingest.is_file():
return
try:
with path_gitingest.open("rb") as f:
data = tomllib.load(f)
except tomllib.TOMLDecodeError as exc:
warnings.warn(f"Invalid TOML in {path_gitingest}: {exc}", UserWarning)
return
config_section = data.get("config", {})
ignore_patterns = config_section.get("ignore_patterns")
if not ignore_patterns:
return
# If a single string is provided, make it a list of one element
if isinstance(ignore_patterns, str):
ignore_patterns = [ignore_patterns]
if not isinstance(ignore_patterns, (list, set)):
warnings.warn(
f"Expected a list/set for 'ignore_patterns', got {type(ignore_patterns)} in {path_gitingest}. Skipping.",
UserWarning,
)
return
# Filter out duplicated patterns
ignore_patterns = set(ignore_patterns)
# Filter out any non-string entries
valid_patterns = {pattern for pattern in ignore_patterns if isinstance(pattern, str)}
invalid_patterns = ignore_patterns - valid_patterns
if invalid_patterns:
warnings.warn(f"Ignore patterns {invalid_patterns} are not strings. Skipping.", UserWarning)
if not valid_patterns:
return
if query.ignore_patterns is None:
query.ignore_patterns = valid_patterns
else:
query.ignore_patterns.update(valid_patterns)
return
def _process_node(
    node: FileSystemNode,
    query: ParsedQuery,
    stats: FileSystemStats,
) -> None:
    """
    Process a file or directory item within a directory.

    This function handles each file or directory item, checking if it should be included or excluded based on the
    provided patterns. It handles symlinks, directories, and files accordingly.

    Parameters
    ----------
    node : FileSystemNode
        The current directory or file node being processed.
    query : ParsedQuery
        The parsed query object containing information about the repository and query parameters.
    stats : FileSystemStats
        Statistics tracking object for the total file count and size.

    Raises
    ------
    ValueError
        If an unexpected error occurs during processing.
    """
    # Stop descending once any traversal limit (depth/files/size) is hit.
    if limit_exceeded(stats, node.depth):
        return

    for sub_path in node.path.iterdir():
        symlink_path = None
        if sub_path.is_symlink():
            if not _is_safe_symlink(sub_path, query.local_path):
                print(f"Skipping unsafe symlink: {sub_path}")
                continue

            # Remember the original link location; sub_path becomes its target.
            symlink_path = sub_path
            sub_path = sub_path.resolve()

        # Skip anything already seen (prevents re-ingesting symlinked targets).
        # NOTE(review): non-symlink paths are recorded un-resolved here — confirm
        # that is intended.
        if sub_path in stats.visited:
            print(f"Skipping already visited path: {sub_path}")
            continue

        stats.visited.add(sub_path)

        if query.ignore_patterns and _should_exclude(sub_path, query.local_path, query.ignore_patterns):
            continue

        if query.include_patterns and not _should_include(sub_path, query.local_path, query.include_patterns):
            continue

        if sub_path.is_file():
            _process_file(path=sub_path, parent_node=node, stats=stats, local_path=query.local_path)
        elif sub_path.is_dir():
            child_directory_node = FileSystemNode(
                name=sub_path.name,
                type=FileSystemNodeType.DIRECTORY,
                path_str=str(sub_path.relative_to(query.local_path)),
                path=sub_path,
                depth=node.depth + 1,
            )

            # rename the subdir to reflect the symlink name
            if symlink_path:
                child_directory_node.name = symlink_path.name
                child_directory_node.path_str = str(symlink_path)

            # Recurse first so the child's totals are final before rollup.
            _process_node(
                node=child_directory_node,
                query=query,
                stats=stats,
            )
            node.children.append(child_directory_node)
            node.size += child_directory_node.size
            node.file_count += child_directory_node.file_count
            node.dir_count += 1 + child_directory_node.dir_count
        else:
            raise ValueError(f"Unexpected error: {sub_path} is neither a file nor a directory")

    node.sort_children()
def _process_file(path: Path, parent_node: FileSystemNode, stats: FileSystemStats, local_path: Path) -> None:
    """
    Record a single file under *parent_node* and update traversal statistics.

    A file that would push the accumulated size past MAX_TOTAL_SIZE_BYTES is
    skipped entirely; once MAX_FILES is exceeded the file is counted in the
    stats but not added to the tree.

    Parameters
    ----------
    path : Path
        The full path of the file.
    parent_node : FileSystemNode
        The directory node that receives the new file node.
    stats : FileSystemStats
        Statistics tracking object for the total file count and size.
    local_path : Path
        The base path of the repository or directory being processed.
    """
    file_size = path.stat().st_size

    # Reject the file outright if it would blow the total-size budget.
    if stats.total_size + file_size > MAX_TOTAL_SIZE_BYTES:
        print(f"Skipping file {path}: would exceed total size limit")
        return

    # NOTE: counters are bumped before the MAX_FILES check, so the file that
    # crosses the limit is counted but not recorded in the tree.
    stats.total_files += 1
    stats.total_size += file_size

    if stats.total_files > MAX_FILES:
        print(f"Maximum file limit ({MAX_FILES}) reached")
        return

    file_node = FileSystemNode(
        name=path.name,
        type=FileSystemNodeType.FILE,
        size=file_size,
        file_count=1,
        path_str=str(path.relative_to(local_path)),
        path=path,
        depth=parent_node.depth + 1,
    )

    parent_node.children.append(file_node)
    parent_node.size += file_size
    parent_node.file_count += 1
def limit_exceeded(stats: FileSystemStats, depth: int) -> bool:
    """
    Check if any of the traversal limits have been exceeded.

    This function checks if the current traversal has exceeded any of the configured limits:
    maximum directory depth, maximum number of files, or maximum total size in bytes.

    Parameters
    ----------
    stats : FileSystemStats
        Statistics tracking object for the total file count and size.
    depth : int
        The current depth of directory traversal.

    Returns
    -------
    bool
        True if any limit has been exceeded, False otherwise.
    """
    if depth > MAX_DIRECTORY_DEPTH:
        print(f"Maximum depth limit ({MAX_DIRECTORY_DEPTH}) reached")
        return True

    if stats.total_files >= MAX_FILES:
        print(f"Maximum file limit ({MAX_FILES}) reached")
        return True  # TODO: end recursion

    if stats.total_size >= MAX_TOTAL_SIZE_BYTES:
        # Typo fix: message previously read "Maxumum".
        print(f"Maximum total size limit ({MAX_TOTAL_SIZE_BYTES/1024/1024:.1f}MB) reached")
        return True  # TODO: end recursion

    return False

View file

@ -0,0 +1,210 @@
""" Functions to ingest and analyze a codebase directory or single file. """
from typing import Optional, Tuple
import tiktoken
from gitingest.filesystem_schema import FileSystemNode, FileSystemNodeType
from gitingest.query_parsing import ParsedQuery
def _create_summary_string(query: ParsedQuery, node: FileSystemNode) -> str:
"""
Create a summary string with file counts and content size.
This function generates a summary of the repository's contents, including the number
of files analyzed, the total content size, and other relevant details based on the query parameters.
Parameters
----------
query : ParsedQuery
The parsed query object containing information about the repository and query parameters.
node : FileSystemNode
The root node representing the directory structure, including file and directory counts.
Returns
-------
str
Summary string containing details such as repository name, file count, and other query-specific information.
"""
if query.user_name:
summary = f"Repository: {query.user_name}/{query.repo_name}\n"
else:
# Local scenario
summary = f"Directory: {query.slug}\n"
if query.commit:
summary += f"Commit: {query.commit}\n"
elif query.branch and query.branch not in ("main", "master"):
summary += f"Branch: {query.branch}\n"
if query.subpath != "/":
summary += f"Subpath: {query.subpath}\n"
summary += f"Files analyzed: {node.file_count}\n"
# TODO: Do we want to add the total number of lines?
return summary
def format_single_file(file_node: FileSystemNode, query: ParsedQuery) -> Tuple[str, str, str]:
    """
    Format a single file for display.

    This function generates a summary, tree structure, and content for a single file.
    It includes information such as the repository name, commit/branch, file name,
    line count, and estimated token count.

    Parameters
    ----------
    file_node : FileSystemNode
        The node representing the file to format.
    query : ParsedQuery
        The parsed query object containing information about the repository and query parameters.

    Returns
    -------
    Tuple[str, str, str]
        A tuple containing the summary, tree structure, and file content.

    Raises
    ------
    ValueError
        If the file has no content.
    """
    if not file_node.content:
        raise ValueError(f"File {file_node.name} has no content")

    header = [f"Repository: {query.user_name}/{query.repo_name}"]
    if query.commit:
        header.append(f"Commit: {query.commit}")
    elif query.branch and query.branch not in ("main", "master"):
        # Default branches are omitted from the summary.
        header.append(f"Branch: {query.branch}")
    header.append(f"File: {file_node.name}")
    header.append(f"Lines: {len(file_node.content.splitlines()):,}")
    summary = "".join(line + "\n" for line in header)

    files_content = file_node.content_string
    tree = "Directory structure:\n└── " + file_node.name

    formatted_tokens = _generate_token_string(files_content)
    if formatted_tokens:
        summary += f"\nEstimated tokens: {formatted_tokens}"

    return summary, tree, files_content
def _get_files_content(node: FileSystemNode) -> str:
    """Recursively concatenate the rendered contents of every file under *node*.

    A file node yields its own ``content_string``; a directory joins its
    children's renderings with newlines; anything else yields the empty string.
    """
    if node.type == FileSystemNodeType.DIRECTORY:
        return "\n".join(map(_get_files_content, node.children))
    if node.type == FileSystemNodeType.FILE:
        return node.content_string
    return ""
def _create_tree_structure(query: ParsedQuery, node: FileSystemNode, prefix: str = "", is_last: bool = True) -> str:
    """
    Create a tree-like string representation of the file structure.

    This function generates a string representation of the directory structure, formatted
    as a tree with appropriate indentation for nested directories and files.

    Parameters
    ----------
    query : ParsedQuery
        The parsed query object containing information about the repository and query parameters.
    node : FileSystemNode
        The current directory or file node being processed.
    prefix : str
        A string used for indentation and formatting of the tree structure, by default "".
    is_last : bool
        A flag indicating whether the current node is the last in its directory, by default True.

    Returns
    -------
    str
        A string representing the directory structure formatted as a tree.
    """
    tree = ""

    if not node.name:
        # The root node may arrive without a name; fall back to the query slug.
        node.name = query.slug

    if node.name:
        current_prefix = "└── " if is_last else "├── "
        # Directories are rendered with a trailing slash.
        name = node.name + "/" if node.type == FileSystemNodeType.DIRECTORY else node.name
        tree += prefix + current_prefix + name + "\n"

    if node.type == FileSystemNodeType.DIRECTORY:
        # Adjust prefix only if we added a node name
        # NOTE(review): the indent strings below look garbled in this rendering
        # (likely lost characters, e.g. "    " / "│   ") — confirm against upstream.
        new_prefix = prefix + (" " if is_last else "") if node.name else prefix
        children = node.children
        for i, child in enumerate(children):
            tree += _create_tree_structure(query, node=child, prefix=new_prefix, is_last=i == len(children) - 1)
    return tree
def _generate_token_string(context_string: str) -> Optional[str]:
    """
    Return the number of tokens in a text string.

    Estimates the token count of *context_string* with tiktoken's
    ``cl100k_base`` encoding and formats it for humans: counts above one
    million are rendered as e.g. '1.2M', above one thousand as '1.2k',
    otherwise as the plain number.

    Parameters
    ----------
    context_string : str
        The text string for which the token count is to be estimated.

    Returns
    -------
    str, optional
        The formatted number of tokens as a string (e.g., '1.2k', '1.2M'), or `None` if an error occurs.
    """
    try:
        encoder = tiktoken.get_encoding("cl100k_base")
        total_tokens = len(encoder.encode(context_string, disallowed_special=()))
    except (ValueError, UnicodeEncodeError) as exc:
        print(exc)
        return None

    # Scale down large counts with a one-decimal suffix.
    for threshold, suffix in ((1_000_000, "M"), (1_000, "k")):
        if total_tokens > threshold:
            return f"{total_tokens / threshold:.1f}{suffix}"
    return str(total_tokens)
def format_directory(root_node: FileSystemNode, query: ParsedQuery) -> Tuple[str, str, str]:
    """
    Ingest an entire directory and return its summary, directory structure, and file contents.

    This function processes a directory, extracts its contents, and generates a summary,
    directory structure, and file content. It recursively processes subdirectories as well.

    Parameters
    ----------
    root_node : FileSystemNode
        The root node representing the directory to process.
    query : ParsedQuery
        The parsed query object containing information about the repository and query parameters.

    Returns
    -------
    Tuple[str, str, str]
        A tuple containing the summary, directory structure, and file contents.
    """
    tree = "Directory structure:\n" + _create_tree_structure(query, root_node)
    files_content = _get_files_content(root_node)
    summary = _create_summary_string(query, node=root_node)

    # Token estimate covers the tree rendering plus all file contents.
    token_estimate = _generate_token_string(tree + files_content)
    if token_estimate:
        summary += f"\nEstimated tokens: {token_estimate}"

    return summary, tree, files_content

View file

@ -1,970 +0,0 @@
""" Functions to ingest and analyze a codebase directory or single file. """
import locale
import os
import platform
import warnings
from fnmatch import fnmatch
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import tiktoken
import tomli
from gitingest.config import MAX_DIRECTORY_DEPTH, MAX_FILES, MAX_TOTAL_SIZE_BYTES
from gitingest.exceptions import (
AlreadyVisitedError,
InvalidNotebookError,
MaxFileSizeReachedError,
MaxFilesReachedError,
)
from gitingest.notebook_utils import process_notebook
from gitingest.query_parser import ParsedQuery
try:
locale.setlocale(locale.LC_ALL, "")
except locale.Error:
locale.setlocale(locale.LC_ALL, "C")
def _normalize_path(path: Path) -> Path:
"""
Normalize path for cross-platform compatibility.
Parameters
----------
path : Path
The Path object to normalize.
Returns
-------
Path
The normalized path with platform-specific separators and resolved components.
"""
return Path(os.path.normpath(str(path)))
def _normalize_path_str(path: Union[Path, str]) -> str:
"""
Convert path to string with forward slashes for consistent output.
Parameters
----------
path : str | Path
The path to convert, can be string or Path object.
Returns
-------
str
The normalized path string with forward slashes as separators.
"""
return str(path).replace(os.sep, "/")
def _get_encoding_list() -> List[str]:
"""
Get list of encodings to try, prioritized for the current platform.
Returns
-------
List[str]
List of encoding names to try in priority order, starting with the
platform's default encoding followed by common fallback encodings.
"""
encodings = ["utf-8", "utf-8-sig", "latin"]
if platform.system() == "Windows":
encodings.extend(["cp1252", "iso-8859-1"])
return encodings + [locale.getpreferredencoding()]
def _should_include(path: Path, base_path: Path, include_patterns: Set[str]) -> bool:
"""
Determine if the given file or directory path matches any of the include patterns.
This function checks whether the relative path of a file or directory matches any of the specified patterns. If a
match is found, it returns `True`, indicating that the file or directory should be included in further processing.
Parameters
----------
path : Path
The absolute path of the file or directory to check.
base_path : Path
The base directory from which the relative path is calculated.
include_patterns : Set[str]
A set of patterns to check against the relative path.
Returns
-------
bool
`True` if the path matches any of the include patterns, `False` otherwise.
"""
try:
rel_path = path.relative_to(base_path)
except ValueError:
# If path is not under base_path at all
return False
rel_str = str(rel_path)
for pattern in include_patterns:
if fnmatch(rel_str, pattern):
return True
return False
def _should_exclude(path: Path, base_path: Path, ignore_patterns: Set[str]) -> bool:
"""
Determine if the given file or directory path matches any of the ignore patterns.
This function checks whether the relative path of a file or directory matches
any of the specified ignore patterns. If a match is found, it returns `True`, indicating
that the file or directory should be excluded from further processing.
Parameters
----------
path : Path
The absolute path of the file or directory to check.
base_path : Path
The base directory from which the relative path is calculated.
ignore_patterns : Set[str]
A set of patterns to check against the relative path.
Returns
-------
bool
`True` if the path matches any of the ignore patterns, `False` otherwise.
"""
try:
rel_path = path.relative_to(base_path)
except ValueError:
# If path is not under base_path at all
return True
rel_str = str(rel_path)
for pattern in ignore_patterns:
if pattern and fnmatch(rel_str, pattern):
return True
return False
def _is_safe_symlink(symlink_path: Path, base_path: Path) -> bool:
    """
    Check if a symlink points to a location within the base directory.

    This function resolves the target of a symlink and ensures it is within the specified
    base directory, returning `True` if it is safe, or `False` if the symlink points outside
    the base directory.

    Parameters
    ----------
    symlink_path : Path
        The path of the symlink to check.
    base_path : Path
        The base directory to ensure the symlink points within.

    Returns
    -------
    bool
        `True` if the symlink points within the base directory, `False` otherwise.
    """
    try:
        if platform.system() == "Windows":
            # On Windows an extra os.path.islink guard is applied — presumably
            # because Path symlink detection is unreliable there; TODO confirm.
            if not os.path.islink(str(symlink_path)):
                return False

        target_path = _normalize_path(symlink_path.resolve())
        base_resolved = _normalize_path(base_path.resolve())

        # Safe only if the resolved target is the base dir itself or a descendant of it.
        return base_resolved in target_path.parents or target_path == base_resolved
    except (OSError, ValueError):
        # If there's any error resolving the paths, consider it unsafe
        return False
def _is_text_file(file_path: Path) -> bool:
"""
Determine if a file is likely a text file based on its content.
This function attempts to read the first 1024 bytes of a file and checks for the presence
of non-text characters. It returns `True` if the file is determined to be a text file,
otherwise returns `False`.
Parameters
----------
file_path : Path
The path to the file to check.
Returns
-------
bool
`True` if the file is likely a text file, `False` otherwise.
"""
try:
with file_path.open("rb") as file:
chunk = file.read(1024)
return not bool(chunk.translate(None, bytes([7, 8, 9, 10, 12, 13, 27] + list(range(0x20, 0x100)))))
except OSError:
return False
def _read_file_content(file_path: Path) -> str:
    """
    Read the content of a file.

    This function attempts to open a file and read its contents using UTF-8 encoding.
    If an error occurs during reading (e.g., file is not found or permission error),
    it returns an error message.

    Parameters
    ----------
    file_path : Path
        The path to the file to read.

    Returns
    -------
    str
        The content of the file, or an error message if the file could not be read.
    """
    try:
        if file_path.suffix == ".ipynb":
            try:
                # Jupyter notebooks get a dedicated plain-text conversion.
                return process_notebook(file_path)
            except Exception as e:
                return f"Error processing notebook: {e}"

        # Try each candidate encoding in priority order until one decodes cleanly.
        for encoding in _get_encoding_list():
            try:
                with open(file_path, encoding=encoding) as f:
                    return f.read()
            except UnicodeDecodeError:
                continue
            except OSError as e:
                return f"Error reading file: {e}"
        return "Error: Unable to decode file with available encodings"
    except (OSError, InvalidNotebookError) as e:
        return f"Error reading file: {e}"
def _sort_children(children: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Sort the children nodes of a directory according to a specific order.
Order of sorting:
1. README.md first
2. Regular files (not starting with dot)
3. Hidden files (starting with dot)
4. Regular directories (not starting with dot)
5. Hidden directories (starting with dot)
All groups are sorted alphanumerically within themselves.
Parameters
----------
children : List[Dict[str, Any]]
List of file and directory nodes to sort.
Returns
-------
List[Dict[str, Any]]
Sorted list according to the specified order.
"""
# Separate files and directories
files = [child for child in children if child["type"] == "file"]
directories = [child for child in children if child["type"] == "directory"]
# Find README.md
readme_files = [f for f in files if f["name"].lower() == "readme.md"]
other_files = [f for f in files if f["name"].lower() != "readme.md"]
# Separate hidden and regular files/directories
regular_files = [f for f in other_files if not f["name"].startswith(".")]
hidden_files = [f for f in other_files if f["name"].startswith(".")]
regular_dirs = [d for d in directories if not d["name"].startswith(".")]
hidden_dirs = [d for d in directories if d["name"].startswith(".")]
# Sort each group alphanumerically
regular_files.sort(key=lambda x: x["name"])
hidden_files.sort(key=lambda x: x["name"])
regular_dirs.sort(key=lambda x: x["name"])
hidden_dirs.sort(key=lambda x: x["name"])
# Combine all groups in the desired order
return readme_files + regular_files + hidden_files + regular_dirs + hidden_dirs
def _scan_directory(
    path: Path,
    query: ParsedQuery,
    seen_paths: Optional[Set[Path]] = None,
    depth: int = 0,
    stats: Optional[Dict[str, int]] = None,
) -> Optional[Dict[str, Any]]:
    """
    Recursively analyze a directory and its contents with safety limits.

    This function scans a directory and its subdirectories up to a specified depth. It checks
    for any file or directory that should be included or excluded based on the provided patterns
    and limits. It also tracks the number of files and total size processed. Scanning stops
    early (returning `None`) once any of the MAX_DIRECTORY_DEPTH, MAX_FILES or
    MAX_TOTAL_SIZE_BYTES limits is hit, or when the resolved path was already visited.

    Parameters
    ----------
    path : Path
        The path of the directory to scan.
    query : ParsedQuery
        The parsed query object containing information about the repository and query parameters.
    seen_paths : Set[Path] | None, optional
        A set of already visited *resolved* paths, used to avoid symlink cycles, by default None.
    depth : int
        The current depth of directory traversal, by default 0.
    stats : Dict[str, int] | None, optional
        A dictionary to track statistics such as total file count and size, by default None.

    Returns
    -------
    Dict[str, Any] | None
        A dictionary representing the directory structure and contents, or `None` if limits are reached.
    """
    # Initialize shared traversal state on the top-level call; recursive calls
    # receive the same objects so limits and cycle detection span the whole scan.
    if seen_paths is None:
        seen_paths = set()
    if stats is None:
        stats = {"total_files": 0, "total_size": 0}
    # Safety limits: bail out with None instead of raising, so the caller keeps
    # whatever was collected before the limit was reached.
    if depth > MAX_DIRECTORY_DEPTH:
        print(f"Skipping deep directory: {path} (max depth {MAX_DIRECTORY_DEPTH} reached)")
        return None
    if stats["total_files"] >= MAX_FILES:
        print(f"Skipping further processing: maximum file limit ({MAX_FILES}) reached")
        return None
    if stats["total_size"] >= MAX_TOTAL_SIZE_BYTES:
        print(f"Skipping further processing: maximum total size ({MAX_TOTAL_SIZE_BYTES/1024/1024:.1f}MB) reached")
        return None
    # Resolve symlinks so the same physical directory is never scanned twice.
    real_path = path.resolve()
    if real_path in seen_paths:
        print(f"Skipping already visited path: {path}")
        return None
    seen_paths.add(real_path)
    # Aggregate node for this directory; children/counters are filled in below.
    result = {
        "name": path.name,
        "type": "directory",
        "size": 0,
        "children": [],
        "file_count": 0,
        "dir_count": 0,
        "path": str(path),
        "ignore_content": False,
    }
    try:
        for item in path.iterdir():
            _process_item(item=item, query=query, result=result, seen_paths=seen_paths, stats=stats, depth=depth)
    except MaxFilesReachedError:
        # Stop iterating this directory but keep what was collected so far.
        print(f"Maximum file limit ({MAX_FILES}) reached.")
    except PermissionError:
        print(f"Permission denied: {path}.")
    result["children"] = _sort_children(result["children"])
    return result
def _process_symlink(
    item: Path,
    query: ParsedQuery,
    result: Dict[str, Any],
    seen_paths: Set[Path],
    stats: Dict[str, int],
    depth: int,
) -> None:
    """
    Process a symlink in the file system.

    This function checks if a symlink is safe (its resolved target stays inside the local
    repository path), resolves its target, and ingests the target as a file or a directory.
    If the symlink is unsafe or its target was already visited, an exception is raised.

    Parameters
    ----------
    item : Path
        The full path of the symlink.
    query : ParsedQuery
        The parsed query object containing information about the repository and query parameters.
    result : Dict[str, Any]
        The dictionary to accumulate the results.
    seen_paths : Set[Path]
        A set of already visited (resolved) paths.
    stats : Dict[str, int]
        The dictionary to track statistics such as file count and size.
    depth : int
        The current depth in the directory traversal.

    Raises
    ------
    AlreadyVisitedError
        If the symlink is unsafe (escapes the repository) or its target has already been processed.
    MaxFileSizeReachedError
        If the file size exceeds the maximum limit.
    MaxFilesReachedError
        If the number of files exceeds the maximum limit.
    """
    # Unsafe symlinks (target outside the repo) are reported as "already
    # visited" so the caller simply skips them.
    if not _is_safe_symlink(item, query.local_path):
        raise AlreadyVisitedError(str(item))
    real_path = item.resolve()
    if real_path in seen_paths:
        raise AlreadyVisitedError(str(item))
    if real_path.is_file():
        # NOTE(review): this branch mirrors _process_file, but stats/content come
        # from the resolved target while name/path record the symlink itself.
        file_size = real_path.stat().st_size
        if stats["total_size"] + file_size > MAX_TOTAL_SIZE_BYTES:
            raise MaxFileSizeReachedError(MAX_TOTAL_SIZE_BYTES)
        stats["total_files"] += 1
        stats["total_size"] += file_size
        if stats["total_files"] > MAX_FILES:
            print(f"Maximum file limit ({MAX_FILES}) reached")
            raise MaxFilesReachedError(MAX_FILES)
        is_text = _is_text_file(real_path)
        content = _read_file_content(real_path) if is_text else "[Non-text file]"
        child = {
            "name": item.name,
            "type": "file",
            "size": file_size,
            "content": content,
            "path": str(item),
        }
        result["children"].append(child)
        result["size"] += file_size
        result["file_count"] += 1
    elif real_path.is_dir():
        subdir = _scan_directory(
            path=real_path,
            query=query,
            seen_paths=seen_paths,
            depth=depth + 1,
            stats=stats,
        )
        # Keep the subtree only if it still contains files under the include filter.
        if subdir and (not query.include_patterns or subdir["file_count"] > 0):
            # rename the subdir to reflect the symlink name
            subdir["name"] = item.name
            subdir["path"] = str(item)
            result["children"].append(subdir)
            result["size"] += subdir["size"]
            result["file_count"] += subdir["file_count"]
            result["dir_count"] += 1 + subdir["dir_count"]
def _process_file(item: Path, result: Dict[str, Any], stats: Dict[str, int]) -> None:
    """
    Process a file in the file system.

    This function checks the file's size against the global limits, updates the
    traversal statistics, reads the file's content (for text files) and appends
    the resulting node to ``result``.

    Parameters
    ----------
    item : Path
        The full path of the file.
    result : Dict[str, Any]
        The dictionary to accumulate the results.
    stats : Dict[str, int]
        The dictionary to track statistics such as file count and size.

    Raises
    ------
    MaxFileSizeReachedError
        If adding this file would exceed the total size limit.
    MaxFilesReachedError
        If the number of files exceeds the maximum limit.
    """
    size = item.stat().st_size
    # Reject the file before counting it when it would push the scan over the
    # total size budget.
    if stats["total_size"] + size > MAX_TOTAL_SIZE_BYTES:
        print(f"Skipping file {item}: would exceed total size limit")
        raise MaxFileSizeReachedError(MAX_TOTAL_SIZE_BYTES)
    stats["total_files"] += 1
    stats["total_size"] += size
    if stats["total_files"] > MAX_FILES:
        print(f"Maximum file limit ({MAX_FILES}) reached")
        raise MaxFilesReachedError(MAX_FILES)
    # Non-text files are recorded with a placeholder instead of their bytes.
    entry = {
        "name": item.name,
        "type": "file",
        "size": size,
        "content": _read_file_content(item) if _is_text_file(item) else "[Non-text file]",
        "path": str(item),
    }
    result["children"].append(entry)
    result["size"] += size
    result["file_count"] += 1
def _process_item(
    item: Path,
    query: ParsedQuery,
    result: Dict[str, Any],
    seen_paths: Set[Path],
    stats: Dict[str, int],
    depth: int,
) -> None:
    """
    Process a file or directory item within a directory.

    This function handles each file or directory item, checking if it should be included or excluded based on the
    provided patterns. It handles symlinks, directories, and files accordingly.

    Parameters
    ----------
    item : Path
        The full path of the file or directory to process.
    query : ParsedQuery
        The parsed query object containing information about the repository and query parameters.
    result : Dict[str, Any]
        The result dictionary to accumulate processed file/directory data.
    seen_paths : Set[Path]
        A set of paths that have already been visited.
    stats : Dict[str, int]
        A dictionary of statistics like the total file count and size.
    depth : int
        The current depth of directory traversal.
    """
    # Exclude the item only when an ignore pattern actually matches it.
    # (Previously the check returned for *every* item when no ignore patterns
    # were configured, which would skip the whole tree.)
    if query.ignore_patterns and _should_exclude(item, query.local_path, query.ignore_patterns):
        return
    if (
        item.is_file()
        and query.include_patterns
        and not _should_include(item, query.local_path, query.include_patterns)
    ):
        result["ignore_content"] = True
        return
    try:
        # These branches must be mutually exclusive: Path.is_file()/is_dir()
        # follow symlinks, so without `elif` a symlink to a file would be
        # ingested twice (once by _process_symlink, once by _process_file).
        if item.is_symlink():
            _process_symlink(item=item, query=query, result=result, seen_paths=seen_paths, stats=stats, depth=depth)
        elif item.is_file():
            _process_file(item=item, result=result, stats=stats)
        elif item.is_dir():
            subdir = _scan_directory(path=item, query=query, seen_paths=seen_paths, depth=depth + 1, stats=stats)
            if subdir and (not query.include_patterns or subdir["file_count"] > 0):
                result["children"].append(subdir)
                result["size"] += subdir["size"]
                result["file_count"] += subdir["file_count"]
                result["dir_count"] += 1 + subdir["dir_count"]
    except (MaxFileSizeReachedError, AlreadyVisitedError) as exc:
        # Best-effort traversal: report and continue with the next sibling.
        print(exc)
def _extract_files_content(
    query: ParsedQuery,
    node: Dict[str, Any],
    files: Optional[List[Dict[str, Any]]] = None,
) -> List[Dict[str, Any]]:
    """
    Recursively collect all text files with their contents.

    This function walks the directory tree and gathers every text file into a
    flat list; non-text files are skipped, and files above the query's size
    limit are recorded with `None` content.

    Parameters
    ----------
    query : ParsedQuery
        The parsed query object containing information about the repository and query parameters.
    node : Dict[str, Any]
        The current directory or file node being processed.
    files : List[Dict[str, Any]] | None, optional
        A list to collect the extracted files' information, by default None.

    Returns
    -------
    List[Dict[str, Any]]
        A list of dictionaries, each containing the path, content (or `None` if too large), and size of each file.
    """
    collected = [] if files is None else files
    if node["type"] == "file" and node["content"] != "[Non-text file]":
        # Oversized files keep their metadata but drop the content.
        body = None if node["size"] > query.max_file_size else node["content"]
        relative_path = Path(node["path"]).relative_to(query.local_path)
        collected.append(
            {
                # Store paths with forward slashes
                "path": _normalize_path_str(relative_path),
                "content": body,
                "size": node["size"],
            },
        )
    elif node["type"] == "directory":
        for child_node in node["children"]:
            _extract_files_content(query=query, node=child_node, files=collected)
    return collected
def _create_file_content_string(files: List[Dict[str, Any]]) -> str:
    """
    Create a formatted string of file contents with separators.

    This function takes a list of files and generates a formatted string where each file's
    content is preceded by a divider and its (forward-slash normalized) path.

    Parameters
    ----------
    files : List[Dict[str, Any]]
        A list of dictionaries containing file information, including the path and content.

    Returns
    -------
    str
        A formatted string representing the contents of all the files with appropriate separators.
    """
    separator = "=" * 48 + "\n"
    chunks: List[str] = []
    for entry in files:
        # Files with empty or dropped (None) content are omitted entirely.
        if not entry["content"]:
            continue
        chunks.append(separator)
        # Use forward slashes in output paths
        chunks.append(f"File: {_normalize_path_str(entry['path'])}\n")
        chunks.append(separator)
        chunks.append(f"{entry['content']}\n\n")
    return "".join(chunks)
def _create_summary_string(query: ParsedQuery, nodes: Dict[str, Any]) -> str:
"""
Create a summary string with file counts and content size.
This function generates a summary of the repository's contents, including the number
of files analyzed, the total content size, and other relevant details based on the query parameters.
Parameters
----------
query : ParsedQuery
The parsed query object containing information about the repository and query parameters.
nodes : Dict[str, Any]
Dictionary representing the directory structure, including file and directory counts.
Returns
-------
str
Summary string containing details such as repository name, file count, and other query-specific information.
"""
if query.user_name:
summary = f"Repository: {query.user_name}/{query.repo_name}\n"
else:
summary = f"Repository: {query.slug}\n"
summary += f"Files analyzed: {nodes['file_count']}\n"
if query.subpath != "/":
summary += f"Subpath: {query.subpath}\n"
if query.commit:
summary += f"Commit: {query.commit}\n"
elif query.branch and query.branch not in ("main", "master"):
summary += f"Branch: {query.branch}\n"
return summary
def _create_tree_structure(query: ParsedQuery, node: Dict[str, Any], prefix: str = "", is_last: bool = True) -> str:
"""
Create a tree-like string representation of the file structure.
This function generates a string representation of the directory structure, formatted
as a tree with appropriate indentation for nested directories and files.
Parameters
----------
query : ParsedQuery
The parsed query object containing information about the repository and query parameters.
node : Dict[str, Any]
The current directory or file node being processed.
prefix : str
A string used for indentation and formatting of the tree structure, by default "".
is_last : bool
A flag indicating whether the current node is the last in its directory, by default True.
Returns
-------
str
A string representing the directory structure formatted as a tree.
"""
tree = ""
if not node["name"]:
node["name"] = query.slug
if node["name"]:
current_prefix = "└── " if is_last else "├── "
name = node["name"] + "/" if node["type"] == "directory" else node["name"]
tree += prefix + current_prefix + name + "\n"
if node["type"] == "directory":
# Adjust prefix only if we added a node name
new_prefix = prefix + (" " if is_last else "") if node["name"] else prefix
children = node["children"]
for i, child in enumerate(children):
tree += _create_tree_structure(query, child, new_prefix, i == len(children) - 1)
return tree
def _generate_token_string(context_string: str) -> Optional[str]:
    """
    Return the number of tokens in a text string, human-formatted.

    This function estimates the number of tokens in a given text string using the `tiktoken`
    library (cl100k_base encoding) and formats large counts as e.g. '1.2k' or '1.2M'.

    Parameters
    ----------
    context_string : str
        The text string for which the token count is to be estimated.

    Returns
    -------
    str, optional
        The formatted number of tokens as a string (e.g., '1.2k', '1.2M'), or `None` if an error occurs.
    """
    try:
        encoder = tiktoken.get_encoding("cl100k_base")
        token_count = len(encoder.encode(context_string, disallowed_special=()))
    except (ValueError, UnicodeEncodeError) as exc:
        print(exc)
        return None
    # Format with the largest matching magnitude suffix.
    for threshold, suffix in ((1_000_000, "M"), (1_000, "k")):
        if token_count > threshold:
            return f"{token_count / threshold:.1f}{suffix}"
    return str(token_count)
def _ingest_single_file(path: Path, query: ParsedQuery) -> Tuple[str, str, str]:
    """
    Ingest a single file and return its summary, directory structure, and content.

    This function reads a file, generates a summary of its contents, and returns the content
    along with its directory structure and token estimation.

    Parameters
    ----------
    path : Path
        The path of the file to ingest.
    query : ParsedQuery
        The parsed query object containing information about the repository and query parameters.

    Returns
    -------
    Tuple[str, str, str]
        A tuple containing the summary, directory structure, and file content.

    Raises
    ------
    ValueError
        If the specified path is not a file or if the file is not a text file.
    """
    if not path.is_file():
        raise ValueError(f"Path {path} is not a file")
    if not _is_text_file(path):
        raise ValueError(f"File {path} is not a text file")
    size = path.stat().st_size
    # Oversized files keep a placeholder instead of their real content.
    content = "[Content ignored: file too large]" if size > query.max_file_size else _read_file_content(path)
    relative_path = path.relative_to(query.local_path)
    files_content = _create_file_content_string(
        [
            {
                "path": str(relative_path),
                "content": content,
                "size": size,
            }
        ]
    )
    tree = "Directory structure:\n└── " + path.name
    summary = (
        f"Repository: {query.user_name}/{query.repo_name}\n"
        f"File: {path.name}\n"
        f"Size: {size:,} bytes\n"
        f"Lines: {len(content.splitlines()):,}\n"
    )
    token_estimate = _generate_token_string(files_content)
    if token_estimate:
        summary += f"\nEstimated tokens: {token_estimate}"
    return summary, tree, files_content
def _ingest_directory(path: Path, query: ParsedQuery) -> Tuple[str, str, str]:
    """
    Ingest an entire directory and return its summary, directory structure, and file contents.

    This function processes a directory, extracts its contents, and generates a summary,
    directory structure, and file content. It recursively processes subdirectories as well.

    Parameters
    ----------
    path : Path
        The path of the directory to ingest.
    query : ParsedQuery
        The parsed query object containing information about the repository and query parameters.

    Returns
    -------
    Tuple[str, str, str]
        A tuple containing the summary, directory structure, and file contents.

    Raises
    ------
    ValueError
        If no files are found in the directory.
    """
    root_node = _scan_directory(path=path, query=query)
    if not root_node:
        raise ValueError(f"No files found in {path}")
    extracted = _extract_files_content(query=query, node=root_node)
    files_content = _create_file_content_string(extracted)
    tree = "Directory structure:\n" + _create_tree_structure(query, root_node)
    summary = _create_summary_string(query, root_node)
    # Token estimate covers everything the user will receive: tree + contents.
    token_estimate = _generate_token_string(tree + files_content)
    if token_estimate:
        summary += f"\nEstimated tokens: {token_estimate}"
    return summary, tree, files_content
def run_ingest_query(query: ParsedQuery) -> Tuple[str, str, str]:
    """
    Run the ingestion process for a parsed query.

    This is the main entry point for analyzing a codebase directory or single file. It processes the query
    parameters, reads the file or directory content, and generates a summary, directory structure, and file content,
    along with token estimations.

    Parameters
    ----------
    query : ParsedQuery
        The parsed query object containing information about the repository and query parameters.

    Returns
    -------
    Tuple[str, str, str]
        A tuple containing the summary, directory structure, and file contents.

    Raises
    ------
    ValueError
        If the specified path cannot be found or if the file is not a text file.
    """
    subpath = _normalize_path(Path(query.subpath.strip("/"))).as_posix()
    target_path = _normalize_path(query.local_path / subpath)
    if not target_path.exists():
        raise ValueError(f"{query.slug} cannot be found")
    # A "blob" query targets a single file; everything else is a directory scan.
    if query.type == "blob":
        return _ingest_single_file(target_path, query)
    apply_gitingest_file(target_path, query)
    return _ingest_directory(target_path, query)
def apply_gitingest_file(path: Path, query: ParsedQuery) -> None:
    """
    Apply the .gitingest file to the query object.

    This function reads the ``.gitingest`` TOML file in the specified path and merges the
    ignore patterns found in its ``[config]`` section into the query object.

    Parameters
    ----------
    path : Path
        The path of the directory to ingest.
    query : ParsedQuery
        The parsed query object containing information about the repository and query parameters.
        It should have an attribute `ignore_patterns` which is either None or a set of strings.
    """
    gitingest_file = path / ".gitingest"
    if not gitingest_file.is_file():
        return
    try:
        with gitingest_file.open("rb") as fp:
            data = tomli.load(fp)
    except tomli.TOMLDecodeError as exc:
        warnings.warn(f"Invalid TOML in {gitingest_file}: {exc}", UserWarning)
        return
    ignore_patterns = data.get("config", {}).get("ignore_patterns")
    if not ignore_patterns:
        return
    # If a single string is provided, make it a list of one element
    if isinstance(ignore_patterns, str):
        ignore_patterns = [ignore_patterns]
    elif not isinstance(ignore_patterns, (list, set)):
        warnings.warn(
            f"Expected a list/set for 'ignore_patterns', got {type(ignore_patterns)} in {gitingest_file}. Skipping.",
            UserWarning,
        )
        return
    # Deduplicate, then drop (and report) any non-string entries.
    unique_patterns = set(ignore_patterns)
    valid_patterns = {pattern for pattern in unique_patterns if isinstance(pattern, str)}
    invalid_patterns = unique_patterns - valid_patterns
    if invalid_patterns:
        warnings.warn(f"Ignore patterns {invalid_patterns} are not strings. Skipping.", UserWarning)
    if not valid_patterns:
        return
    if query.ignore_patterns is None:
        query.ignore_patterns = valid_patterns
    else:
        query.ignore_patterns.update(valid_patterns)

View file

@ -1,30 +1,26 @@
""" This module contains functions to parse and validate input sources and patterns. """
import os
import re
import string
import uuid
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Set, Tuple, Union
from typing import List, Optional, Set, Union
from urllib.parse import unquote, urlparse
from gitingest.cloning import CloneConfig, _check_repo_exists, fetch_remote_branch_list
from gitingest.config import MAX_FILE_SIZE, TMP_BASE_PATH
from gitingest.exceptions import InvalidPatternError
from gitingest.ignore_patterns import DEFAULT_IGNORE_PATTERNS
from gitingest.repository_clone import CloneConfig, _check_repo_exists, fetch_remote_branch_list
HEX_DIGITS: Set[str] = set(string.hexdigits)
KNOWN_GIT_HOSTS: List[str] = [
"github.com",
"gitlab.com",
"bitbucket.org",
"gitea.com",
"codeberg.org",
"gist.github.com",
]
from gitingest.utils.ignore_patterns import DEFAULT_IGNORE_PATTERNS
from gitingest.utils.query_parser_utils import (
KNOWN_GIT_HOSTS,
_get_user_and_repo_from_path,
_is_valid_git_commit_hash,
_is_valid_pattern,
_normalize_pattern,
_validate_host,
_validate_url_scheme,
)
@dataclass
@ -71,6 +67,7 @@ class ParsedQuery: # pylint: disable=too-many-instance-attributes
commit=self.commit,
branch=self.branch,
subpath=self.subpath,
blob=self.type == "blob",
)
@ -110,10 +107,10 @@ async def parse_query(
# Determine the parsing method based on the source type
if from_web or urlparse(source).scheme in ("https", "http") or any(h in source for h in KNOWN_GIT_HOSTS):
# We either have a full URL or a domain-less slug
parsed_query = await _parse_repo_source(source)
parsed_query = await _parse_remote_repo(source)
else:
# Local path scenario
parsed_query = _parse_path(source)
parsed_query = _parse_local_dir_path(source)
# Combine default ignore patterns + custom patterns
ignore_patterns_set = DEFAULT_IGNORE_PATTERNS.copy()
@ -123,7 +120,8 @@ async def parse_query(
# Process include patterns and override ignore patterns accordingly
if include_patterns:
parsed_include = _parse_patterns(include_patterns)
ignore_patterns_set = _override_ignore_patterns(ignore_patterns_set, include_patterns=parsed_include)
# Override ignore patterns with include patterns
ignore_patterns_set = set(ignore_patterns_set) - set(parsed_include)
else:
parsed_include = None
@ -144,7 +142,7 @@ async def parse_query(
)
async def _parse_repo_source(source: str) -> ParsedQuery:
async def _parse_remote_repo(source: str) -> ParsedQuery:
"""
Parse a repository URL into a structured query dictionary.
@ -169,7 +167,7 @@ async def _parse_repo_source(source: str) -> ParsedQuery:
parsed_url = urlparse(source)
if parsed_url.scheme:
_validate_scheme(parsed_url.scheme)
_validate_url_scheme(parsed_url.scheme)
_validate_host(parsed_url.netloc.lower())
else: # Will be of the form 'host/user/repo' or 'user/repo'
@ -251,8 +249,8 @@ async def _configure_branch_and_subpath(remaining_parts: List[str], url: str) ->
try:
# Fetch the list of branches from the remote repository
branches: List[str] = await fetch_remote_branch_list(url)
except RuntimeError as e:
warnings.warn(f"Warning: Failed to fetch branch list: {e}", RuntimeWarning)
except RuntimeError as exc:
warnings.warn(f"Warning: Failed to fetch branch list: {exc}", RuntimeWarning)
return remaining_parts.pop(0)
branch = []
@ -265,49 +263,6 @@ async def _configure_branch_and_subpath(remaining_parts: List[str], url: str) ->
return None
def _is_valid_git_commit_hash(commit: str) -> bool:
"""
Validate if the provided string is a valid Git commit hash.
This function checks if the commit hash is a 40-character string consisting only
of hexadecimal digits, which is the standard format for Git commit hashes.
Parameters
----------
commit : str
The string to validate as a Git commit hash.
Returns
-------
bool
True if the string is a valid 40-character Git commit hash, otherwise False.
"""
return len(commit) == 40 and all(c in HEX_DIGITS for c in commit)
def _normalize_pattern(pattern: str) -> str:
"""
Normalize the given pattern by removing leading separators and appending a wildcard.
This function processes the pattern string by stripping leading directory separators
and appending a wildcard (`*`) if the pattern ends with a separator.
Parameters
----------
pattern : str
The pattern to normalize.
Returns
-------
str
The normalized pattern.
"""
pattern = pattern.lstrip(os.sep)
if pattern.endswith(os.sep):
pattern += "*"
return pattern
def _parse_patterns(pattern: Union[str, Set[str]]) -> Set[str]:
"""
Parse and validate file/directory patterns for inclusion or exclusion.
@ -349,26 +304,7 @@ def _parse_patterns(pattern: Union[str, Set[str]]) -> Set[str]:
return {_normalize_pattern(p) for p in parsed_patterns}
def _override_ignore_patterns(ignore_patterns: Set[str], include_patterns: Set[str]) -> Set[str]:
"""
Remove patterns from ignore_patterns that are present in include_patterns using set difference.
Parameters
----------
ignore_patterns : Set[str]
The set of ignore patterns to filter.
include_patterns : Set[str]
The set of include patterns to remove from ignore_patterns.
Returns
-------
Set[str]
The filtered set of ignore patterns.
"""
return set(ignore_patterns) - set(include_patterns)
def _parse_path(path_str: str) -> ParsedQuery:
def _parse_local_dir_path(path_str: str) -> ParsedQuery:
"""
Parse the given file path into a structured query dictionary.
@ -383,37 +319,17 @@ def _parse_path(path_str: str) -> ParsedQuery:
A dictionary containing the parsed details of the file path.
"""
path_obj = Path(path_str).resolve()
slug = path_obj.name if path_str == "." else path_str.strip("/")
return ParsedQuery(
user_name=None,
repo_name=None,
url=None,
local_path=path_obj,
slug=f"{path_obj.parent.name}/{path_obj.name}",
slug=slug,
id=str(uuid.uuid4()),
)
def _is_valid_pattern(pattern: str) -> bool:
"""
Validate if the given pattern contains only valid characters.
This function checks if the pattern contains only alphanumeric characters or one
of the following allowed characters: dash (`-`), underscore (`_`), dot (`.`),
forward slash (`/`), plus (`+`), asterisk (`*`), or the at sign (`@`).
Parameters
----------
pattern : str
The pattern to validate.
Returns
-------
bool
True if the pattern is valid, otherwise False.
"""
return all(c.isalnum() or c in "-_./+*@" for c in pattern)
async def try_domains_for_user_and_repo(user_name: str, repo_name: str) -> str:
"""
Attempt to find a valid repository host for the given user_name and repo_name.
@ -440,64 +356,3 @@ async def try_domains_for_user_and_repo(user_name: str, repo_name: str) -> str:
if await _check_repo_exists(candidate):
return domain
raise ValueError(f"Could not find a valid repository host for '{user_name}/{repo_name}'.")
def _get_user_and_repo_from_path(path: str) -> Tuple[str, str]:
"""
Extract the user and repository names from a given path.
Parameters
----------
path : str
The path to extract the user and repository names from.
Returns
-------
Tuple[str, str]
A tuple containing the user and repository names.
Raises
------
ValueError
If the path does not contain at least two parts.
"""
path_parts = path.lower().strip("/").split("/")
if len(path_parts) < 2:
raise ValueError(f"Invalid repository URL '{path}'")
return path_parts[0], path_parts[1]
def _validate_host(host: str) -> None:
"""
Validate the given host against the known Git hosts.
Parameters
----------
host : str
The host to validate.
Raises
------
ValueError
If the host is not a known Git host.
"""
if host not in KNOWN_GIT_HOSTS:
raise ValueError(f"Unknown domain '{host}' in URL")
def _validate_scheme(scheme: str) -> None:
"""
Validate the given scheme against the known schemes.
Parameters
----------
scheme : str
The scheme to validate.
Raises
------
ValueError
If the scheme is not 'http' or 'https'.
"""
if scheme not in ("https", "http"):
raise ValueError(f"Invalid URL scheme '{scheme}' in URL")

View file

@ -5,10 +5,10 @@ import inspect
import shutil
from typing import Optional, Set, Tuple, Union
from gitingest.cloning import clone_repo
from gitingest.config import TMP_BASE_PATH
from gitingest.query_ingestion import run_ingest_query
from gitingest.query_parser import ParsedQuery, parse_query
from gitingest.repository_clone import clone_repo
from gitingest.ingestion import ingest_query
from gitingest.query_parsing import ParsedQuery, parse_query
async def ingest_async(
@ -83,7 +83,7 @@ async def ingest_async(
repo_cloned = True
summary, tree, content = run_ingest_query(parsed_query)
summary, tree, content = ingest_query(parsed_query)
if output is not None:
with open(output, "w", encoding="utf-8") as f:
@ -93,7 +93,7 @@ async def ingest_async(
finally:
# Clean up the temporary directory if it was created
if repo_cloned:
shutil.rmtree(TMP_BASE_PATH)
shutil.rmtree(TMP_BASE_PATH, ignore_errors=True)
def ingest(

View file

View file

@ -17,7 +17,7 @@ DEFAULT_IGNORE_PATTERNS: Set[str] = {
".hypothesis",
"poetry.lock",
"Pipfile.lock",
# JavaScript/Node
# JavaScript/Node
"node_modules",
"bower_components",
"package-lock.json",
@ -157,4 +157,6 @@ DEFAULT_IGNORE_PATTERNS: Set[str] = {
"*.tfstate*",
## Dependencies in various languages
"vendor/",
# Gitingest
"digest.txt",
}

View file

@ -0,0 +1,97 @@
""" Utility functions for the ingestion process. """
import locale
import platform
from fnmatch import fnmatch
from pathlib import Path
from typing import List, Set
# Use the user's environment locale so locale.getpreferredencoding() reflects
# their platform defaults; fall back to the portable "C" locale when the
# environment requests an unsupported one.
try:
    locale.setlocale(locale.LC_ALL, "")
except locale.Error:
    locale.setlocale(locale.LC_ALL, "C")
def _get_encoding_list() -> List[str]:
"""
Get list of encodings to try, prioritized for the current platform.
Returns
-------
List[str]
List of encoding names to try in priority order, starting with the
platform's default encoding followed by common fallback encodings.
"""
encodings = [locale.getpreferredencoding(), "utf-8", "utf-16", "utf-16le", "utf-8-sig", "latin"]
if platform.system() == "Windows":
encodings += ["cp1252", "iso-8859-1"]
return encodings
def _should_include(path: Path, base_path: Path, include_patterns: Set[str]) -> bool:
"""
Determine if the given file or directory path matches any of the include patterns.
This function checks whether the relative path of a file or directory matches any of the specified patterns. If a
match is found, it returns `True`, indicating that the file or directory should be included in further processing.
Parameters
----------
path : Path
The absolute path of the file or directory to check.
base_path : Path
The base directory from which the relative path is calculated.
include_patterns : Set[str]
A set of patterns to check against the relative path.
Returns
-------
bool
`True` if the path matches any of the include patterns, `False` otherwise.
"""
try:
rel_path = path.relative_to(base_path)
except ValueError:
# If path is not under base_path at all
return False
rel_str = str(rel_path)
for pattern in include_patterns:
if fnmatch(rel_str, pattern):
return True
return False
def _should_exclude(path: Path, base_path: Path, ignore_patterns: Set[str]) -> bool:
"""
Determine if the given file or directory path matches any of the ignore patterns.
This function checks whether the relative path of a file or directory matches
any of the specified ignore patterns. If a match is found, it returns `True`, indicating
that the file or directory should be excluded from further processing.
Parameters
----------
path : Path
The absolute path of the file or directory to check.
base_path : Path
The base directory from which the relative path is calculated.
ignore_patterns : Set[str]
A set of patterns to check against the relative path.
Returns
-------
bool
`True` if the path matches any of the ignore patterns, `False` otherwise.
"""
try:
rel_path = path.relative_to(base_path)
except ValueError:
# If path is not under base_path at all
return True
rel_str = str(rel_path)
for pattern in ignore_patterns:
if pattern and fnmatch(rel_str, pattern):
return True
return False

View file

@ -33,8 +33,8 @@ def process_notebook(file: Path, include_output: bool = True) -> str:
try:
with file.open(encoding="utf-8") as f:
notebook: Dict[str, Any] = json.load(f)
except json.JSONDecodeError as e:
raise InvalidNotebookError(f"Invalid JSON in notebook: {file}") from e
except json.JSONDecodeError as exc:
raise InvalidNotebookError(f"Invalid JSON in notebook: {file}") from exc
# Check if the notebook contains worksheets
worksheets = notebook.get("worksheets")

View file

@ -0,0 +1,39 @@
""" Utility functions for working with file paths. """
import os
import platform
from pathlib import Path
def _is_safe_symlink(symlink_path: Path, base_path: Path) -> bool:
"""
Check if a symlink points to a location within the base directory.
This function resolves the target of a symlink and ensures it is within the specified
base directory, returning `True` if it is safe, or `False` if the symlink points outside
the base directory.
Parameters
----------
symlink_path : Path
The path of the symlink to check.
base_path : Path
The base directory to ensure the symlink points within.
Returns
-------
bool
`True` if the symlink points within the base directory, `False` otherwise.
"""
try:
if platform.system() == "Windows":
if not os.path.islink(str(symlink_path)):
return False
target_path = symlink_path.resolve()
base_resolved = base_path.resolve()
return base_resolved in target_path.parents or target_path == base_resolved
except (OSError, ValueError):
# If there's any error resolving the paths, consider it unsafe
return False

View file

@ -0,0 +1,142 @@
""" Utility functions for parsing and validating query parameters. """
import os
import string
from typing import List, Set, Tuple
HEX_DIGITS: Set[str] = set(string.hexdigits)
KNOWN_GIT_HOSTS: List[str] = [
"github.com",
"gitlab.com",
"bitbucket.org",
"gitea.com",
"codeberg.org",
"gist.github.com",
]
def _is_valid_git_commit_hash(commit: str) -> bool:
"""
Validate if the provided string is a valid Git commit hash.
This function checks if the commit hash is a 40-character string consisting only
of hexadecimal digits, which is the standard format for Git commit hashes.
Parameters
----------
commit : str
The string to validate as a Git commit hash.
Returns
-------
bool
True if the string is a valid 40-character Git commit hash, otherwise False.
"""
return len(commit) == 40 and all(c in HEX_DIGITS for c in commit)
def _is_valid_pattern(pattern: str) -> bool:
"""
Validate if the given pattern contains only valid characters.
This function checks if the pattern contains only alphanumeric characters or one
of the following allowed characters: dash (`-`), underscore (`_`), dot (`.`),
forward slash (`/`), plus (`+`), asterisk (`*`), or the at sign (`@`).
Parameters
----------
pattern : str
The pattern to validate.
Returns
-------
bool
True if the pattern is valid, otherwise False.
"""
return all(c.isalnum() or c in "-_./+*@" for c in pattern)
def _validate_host(host: str) -> None:
    """
    Ensure ``host`` is one of the known Git hosting domains.

    Parameters
    ----------
    host : str
        The host to validate.

    Raises
    ------
    ValueError
        If the host is not a known Git host.
    """
    if host in KNOWN_GIT_HOSTS:
        return
    raise ValueError(f"Unknown domain '{host}' in URL")
def _validate_url_scheme(scheme: str) -> None:
"""
Validate the given scheme against the known schemes.
Parameters
----------
scheme : str
The scheme to validate.
Raises
------
ValueError
If the scheme is not 'http' or 'https'.
"""
if scheme not in ("https", "http"):
raise ValueError(f"Invalid URL scheme '{scheme}' in URL")
def _get_user_and_repo_from_path(path: str) -> Tuple[str, str]:
"""
Extract the user and repository names from a given path.
Parameters
----------
path : str
The path to extract the user and repository names from.
Returns
-------
Tuple[str, str]
A tuple containing the user and repository names.
Raises
------
ValueError
If the path does not contain at least two parts.
"""
path_parts = path.lower().strip("/").split("/")
if len(path_parts) < 2:
raise ValueError(f"Invalid repository URL '{path}'")
return path_parts[0], path_parts[1]
def _normalize_pattern(pattern: str) -> str:
"""
Normalize the given pattern by removing leading separators and appending a wildcard.
This function processes the pattern string by stripping leading directory separators
and appending a wildcard (`*`) if the pattern ends with a separator.
Parameters
----------
pattern : str
The pattern to normalize.
Returns
-------
str
The normalized pattern.
"""
pattern = pattern.lstrip(os.sep)
if pattern.endswith(os.sep):
pattern += "*"
return pattern

View file

@ -0,0 +1,48 @@
""" Utility functions for checking whether a file is likely a text file or a binary file. """
from pathlib import Path
from gitingest.utils.ingestion_utils import _get_encoding_list
def is_textfile(path: Path) -> bool:
    """
    Heuristically decide whether ``path`` is a text file.

    A small binary sample (up to 1024 bytes) is inspected first for obvious
    binary markers, then a full decode is attempted with each candidate
    encoding.

    Parameters
    ----------
    path : Path
        The path to the file to check.

    Returns
    -------
    bool
        True if the file is likely textual; False if it appears to be binary.
    """
    try:
        with path.open("rb") as handle:
            sample = handle.read(1024)
    except OSError:
        # Unreadable files are treated as non-textual.
        return False

    # An empty file counts as text.
    if not sample:
        return True

    # Null (0x00) or 0xFF bytes are strong indicators of binary content.
    if b"\x00" in sample or b"\xff" in sample:
        return False

    # Fall back to attempting a full decode with each candidate encoding.
    for encoding in _get_encoding_list():
        try:
            with path.open(encoding=encoding) as handle:
                handle.read()
        except UnicodeDecodeError:
            continue
        except OSError:
            return False
        else:
            return True
    return False

View file

@ -5,9 +5,9 @@ from functools import partial
from fastapi import Request
from starlette.templating import _TemplateResponse
from gitingest.query_ingestion import run_ingest_query
from gitingest.query_parser import ParsedQuery, parse_query
from gitingest.repository_clone import clone_repo
from gitingest.cloning import clone_repo
from gitingest.ingestion import ingest_query
from gitingest.query_parsing import ParsedQuery, parse_query
from server.server_config import EXAMPLE_REPOS, MAX_DISPLAY_SIZE, templates
from server.server_utils import Colors, log_slider_to_size
@ -86,20 +86,19 @@ async def process_query(
clone_config = parsed_query.extact_clone_config()
await clone_repo(clone_config)
summary, tree, content = run_ingest_query(parsed_query)
with open(f"{parsed_query.local_path}.txt", "w", encoding="utf-8") as f:
summary, tree, content = ingest_query(parsed_query)
with open(f"{clone_config.local_path}.txt", "w", encoding="utf-8") as f:
f.write(tree + "\n" + content)
except Exception as e:
except Exception as exc:
# hack to print error message when query is not defined
if "query" in locals() and parsed_query is not None and isinstance(parsed_query, dict):
_print_error(parsed_query["url"], e, max_file_size, pattern_type, pattern)
_print_error(parsed_query["url"], exc, max_file_size, pattern_type, pattern)
else:
print(f"{Colors.BROWN}WARN{Colors.END}: {Colors.RED}<- {Colors.END}", end="")
print(f"{Colors.RED}{e}{Colors.END}")
print(f"{Colors.RED}{exc}{Colors.END}")
context["error_message"] = f"Error: {e}"
if "405" in str(e):
context["error_message"] = f"Error: {exc}"
if "405" in str(exc):
context["error_message"] = (
"Repository not found. Please make sure it is public (private repositories will be supported soon)"
)

View file

@ -104,8 +104,8 @@ async def _remove_old_repositories():
await _process_folder(folder)
except Exception as e:
print(f"Error in _remove_old_repositories: {e}")
except Exception as exc:
print(f"Error in _remove_old_repositories: {exc}")
await asyncio.sleep(60)
@ -132,14 +132,14 @@ async def _process_folder(folder: Path) -> None:
with open("history.txt", mode="a", encoding="utf-8") as history:
history.write(f"{repo_url}\n")
except Exception as e:
print(f"Error logging repository URL for {folder}: {e}")
except Exception as exc:
print(f"Error logging repository URL for {folder}: {exc}")
# Delete the folder
try:
shutil.rmtree(folder)
except Exception as e:
print(f"Error deleting {folder}: {e}")
except Exception as exc:
print(f"Error deleting {folder}: {exc}")
def log_slider_to_size(position: int) -> int:

View file

@ -11,7 +11,7 @@ from typing import Any, Callable, Dict
import pytest
from gitingest.query_parser import ParsedQuery
from gitingest.query_parsing import ParsedQuery
WriteNotebookFunc = Callable[[str, Dict[str, Any]], Path]

View file

@ -9,7 +9,7 @@ from typing import List
import pytest
from gitingest.query_parser import parse_query
from gitingest.query_parsing import parse_query
@pytest.mark.parametrize(

View file

@ -1,5 +1,5 @@
"""
Tests for the `query_parser` module.
Tests for the `query_parsing` module.
These tests cover URL parsing, pattern parsing, and handling of branches/subpaths for HTTP(S) repositories and local
paths.
@ -10,17 +10,17 @@ from unittest.mock import AsyncMock, patch
import pytest
from gitingest.ignore_patterns import DEFAULT_IGNORE_PATTERNS
from gitingest.query_parser import _parse_patterns, _parse_repo_source, parse_query
from gitingest.query_parsing import _parse_patterns, _parse_remote_repo, parse_query
from gitingest.utils.ignore_patterns import DEFAULT_IGNORE_PATTERNS
@pytest.mark.asyncio
async def test_parse_url_valid_https() -> None:
"""
Test `_parse_repo_source` with valid HTTPS URLs.
Test `_parse_remote_repo` with valid HTTPS URLs.
Given various HTTPS URLs on supported platforms:
When `_parse_repo_source` is called,
When `_parse_remote_repo` is called,
Then user name, repo name, and the URL should be extracted correctly.
"""
test_cases = [
@ -32,7 +32,7 @@ async def test_parse_url_valid_https() -> None:
"https://gist.github.com/user/repo",
]
for url in test_cases:
parsed_query = await _parse_repo_source(url)
parsed_query = await _parse_remote_repo(url)
assert parsed_query.user_name == "user"
assert parsed_query.repo_name == "repo"
@ -42,10 +42,10 @@ async def test_parse_url_valid_https() -> None:
@pytest.mark.asyncio
async def test_parse_url_valid_http() -> None:
"""
Test `_parse_repo_source` with valid HTTP URLs.
Test `_parse_remote_repo` with valid HTTP URLs.
Given various HTTP URLs on supported platforms:
When `_parse_repo_source` is called,
When `_parse_remote_repo` is called,
Then user name, repo name, and the slug should be extracted correctly.
"""
test_cases = [
@ -57,7 +57,7 @@ async def test_parse_url_valid_http() -> None:
"http://gist.github.com/user/repo",
]
for url in test_cases:
parsed_query = await _parse_repo_source(url)
parsed_query = await _parse_remote_repo(url)
assert parsed_query.user_name == "user"
assert parsed_query.repo_name == "repo"
@ -67,15 +67,15 @@ async def test_parse_url_valid_http() -> None:
@pytest.mark.asyncio
async def test_parse_url_invalid() -> None:
"""
Test `_parse_repo_source` with an invalid URL.
Test `_parse_remote_repo` with an invalid URL.
Given an HTTPS URL lacking a repository structure (e.g., "https://github.com"),
When `_parse_repo_source` is called,
When `_parse_remote_repo` is called,
Then a ValueError should be raised indicating an invalid repository URL.
"""
url = "https://github.com"
with pytest.raises(ValueError, match="Invalid repository URL"):
await _parse_repo_source(url)
await _parse_remote_repo(url)
@pytest.mark.asyncio
@ -146,20 +146,18 @@ async def test_parse_query_invalid_pattern() -> None:
@pytest.mark.asyncio
async def test_parse_url_with_subpaths() -> None:
"""
Test `_parse_repo_source` with a URL containing branch and subpath.
Test `_parse_remote_repo` with a URL containing branch and subpath.
Given a URL referencing a branch ("main") and a subdir ("subdir/file"):
When `_parse_repo_source` is called with remote branch fetching,
When `_parse_remote_repo` is called with remote branch fetching,
Then user, repo, branch, and subpath should be identified correctly.
"""
url = "https://github.com/user/repo/tree/main/subdir/file"
with patch("gitingest.repository_clone._run_command", new_callable=AsyncMock) as mock_run_git_command:
mock_run_git_command.return_value = (b"refs/heads/main\nrefs/heads/dev\nrefs/heads/feature-branch\n", b"")
with patch(
"gitingest.repository_clone.fetch_remote_branch_list", new_callable=AsyncMock
) as mock_fetch_branches:
with patch("gitingest.cloning._run_command", new_callable=AsyncMock) as mock_run_command:
mock_run_command.return_value = (b"refs/heads/main\nrefs/heads/dev\nrefs/heads/feature-branch\n", b"")
with patch("gitingest.cloning.fetch_remote_branch_list", new_callable=AsyncMock) as mock_fetch_branches:
mock_fetch_branches.return_value = ["main", "dev", "feature-branch"]
parsed_query = await _parse_repo_source(url)
parsed_query = await _parse_remote_repo(url)
assert parsed_query.user_name == "user"
assert parsed_query.repo_name == "repo"
@ -170,15 +168,15 @@ async def test_parse_url_with_subpaths() -> None:
@pytest.mark.asyncio
async def test_parse_url_invalid_repo_structure() -> None:
"""
Test `_parse_repo_source` with a URL missing a repository name.
Test `_parse_remote_repo` with a URL missing a repository name.
Given a URL like "https://github.com/user":
When `_parse_repo_source` is called,
When `_parse_remote_repo` is called,
Then a ValueError should be raised indicating an invalid repository URL.
"""
url = "https://github.com/user"
with pytest.raises(ValueError, match="Invalid repository URL"):
await _parse_repo_source(url)
await _parse_remote_repo(url)
def test_parse_patterns_valid() -> None:
@ -279,7 +277,7 @@ async def test_parse_query_local_path() -> None:
assert parsed_query.local_path.parts[-len(tail.parts) :] == tail.parts
assert parsed_query.id is not None
assert parsed_query.slug == "user/project"
assert parsed_query.slug == "home/user/project"
@pytest.mark.asyncio
@ -326,21 +324,19 @@ async def test_parse_query_empty_source() -> None:
)
async def test_parse_url_branch_and_commit_distinction(url: str, expected_branch: str, expected_commit: str) -> None:
"""
Test `_parse_repo_source` distinguishing branch vs. commit hash.
Test `_parse_remote_repo` distinguishing branch vs. commit hash.
Given either a branch URL (e.g., ".../tree/main") or a 40-character commit URL:
When `_parse_repo_source` is called with branch fetching,
When `_parse_remote_repo` is called with branch fetching,
Then the function should correctly set `branch` or `commit` based on the URL content.
"""
with patch("gitingest.repository_clone._run_command", new_callable=AsyncMock) as mock_run_git_command:
with patch("gitingest.cloning._run_command", new_callable=AsyncMock) as mock_run_command:
# Mocking the return value to include 'main' and some additional branches
mock_run_git_command.return_value = (b"refs/heads/main\nrefs/heads/dev\nrefs/heads/feature-branch\n", b"")
with patch(
"gitingest.repository_clone.fetch_remote_branch_list", new_callable=AsyncMock
) as mock_fetch_branches:
mock_run_command.return_value = (b"refs/heads/main\nrefs/heads/dev\nrefs/heads/feature-branch\n", b"")
with patch("gitingest.cloning.fetch_remote_branch_list", new_callable=AsyncMock) as mock_fetch_branches:
mock_fetch_branches.return_value = ["main", "dev", "feature-branch"]
parsed_query = await _parse_repo_source(url)
parsed_query = await _parse_remote_repo(url)
# Verify that `branch` and `commit` match our expectations
assert parsed_query.branch == expected_branch
@ -366,14 +362,14 @@ async def test_parse_query_uuid_uniqueness() -> None:
@pytest.mark.asyncio
async def test_parse_url_with_query_and_fragment() -> None:
"""
Test `_parse_repo_source` with query parameters and a fragment.
Test `_parse_remote_repo` with query parameters and a fragment.
Given a URL like "https://github.com/user/repo?arg=value#fragment":
When `_parse_repo_source` is called,
When `_parse_remote_repo` is called,
Then those parts should be stripped, leaving a clean user/repo URL.
"""
url = "https://github.com/user/repo?arg=value#fragment"
parsed_query = await _parse_repo_source(url)
parsed_query = await _parse_remote_repo(url)
assert parsed_query.user_name == "user"
assert parsed_query.repo_name == "repo"
@ -383,15 +379,15 @@ async def test_parse_url_with_query_and_fragment() -> None:
@pytest.mark.asyncio
async def test_parse_url_unsupported_host() -> None:
"""
Test `_parse_repo_source` with an unsupported host.
Test `_parse_remote_repo` with an unsupported host.
Given "https://only-domain.com":
When `_parse_repo_source` is called,
When `_parse_remote_repo` is called,
Then a ValueError should be raised for the unknown domain.
"""
url = "https://only-domain.com"
with pytest.raises(ValueError, match="Unknown domain 'only-domain.com' in URL"):
await _parse_repo_source(url)
await _parse_remote_repo(url)
@pytest.mark.asyncio
@ -428,13 +424,13 @@ async def test_parse_query_with_branch() -> None:
)
async def test_parse_repo_source_with_failed_git_command(url, expected_branch, expected_subpath):
"""
Test `_parse_repo_source` when git fetch fails.
Test `_parse_remote_repo` when git fetch fails.
Given a URL referencing a branch, but Git fetching fails:
When `_parse_repo_source` is called,
When `_parse_remote_repo` is called,
Then it should fall back to path components for branch identification.
"""
with patch("gitingest.repository_clone.fetch_remote_branch_list", new_callable=AsyncMock) as mock_fetch_branches:
with patch("gitingest.cloning.fetch_remote_branch_list", new_callable=AsyncMock) as mock_fetch_branches:
mock_fetch_branches.side_effect = Exception("Failed to fetch branch list")
with pytest.warns(
@ -443,7 +439,7 @@ async def test_parse_repo_source_with_failed_git_command(url, expected_branch, e
"git ls-remote --heads https://github.com/user/repo",
):
parsed_query = await _parse_repo_source(url)
parsed_query = await _parse_remote_repo(url)
assert parsed_query.branch == expected_branch
assert parsed_query.subpath == expected_subpath
@ -463,23 +459,21 @@ async def test_parse_repo_source_with_failed_git_command(url, expected_branch, e
)
async def test_parse_repo_source_with_various_url_patterns(url, expected_branch, expected_subpath):
"""
Test `_parse_repo_source` with various URL patterns.
Test `_parse_remote_repo` with various URL patterns.
Given multiple branch/blob patterns (including nonexistent branches):
When `_parse_repo_source` is called with remote branch fetching,
When `_parse_remote_repo` is called with remote branch fetching,
Then the correct branch/subpath should be set or None if unmatched.
"""
with patch("gitingest.repository_clone._run_command", new_callable=AsyncMock) as mock_run_git_command:
with patch(
"gitingest.repository_clone.fetch_remote_branch_list", new_callable=AsyncMock
) as mock_fetch_branches:
mock_run_git_command.return_value = (
with patch("gitingest.cloning._run_command", new_callable=AsyncMock) as mock_run_command:
with patch("gitingest.cloning.fetch_remote_branch_list", new_callable=AsyncMock) as mock_fetch_branches:
mock_run_command.return_value = (
b"refs/heads/feature/fix1\nrefs/heads/main\nrefs/heads/feature-branch\nrefs/heads/fix\n",
b"",
)
mock_fetch_branches.return_value = ["feature/fix1", "main", "feature-branch"]
parsed_query = await _parse_repo_source(url)
parsed_query = await _parse_remote_repo(url)
assert parsed_query.branch == expected_branch
assert parsed_query.subpath == expected_subpath

View file

@ -5,17 +5,17 @@ import os
from click.testing import CliRunner
from gitingest.cli import main
from gitingest.config import MAX_FILE_SIZE, OUTPUT_FILE_PATH
from gitingest.config import MAX_FILE_SIZE, OUTPUT_FILE_NAME
def test_cli_with_default_options():
runner = CliRunner()
result = runner.invoke(main, ["./"])
output_lines = result.output.strip().split("\n")
assert f"Analysis complete! Output written to: {OUTPUT_FILE_PATH}" in output_lines
assert os.path.exists(OUTPUT_FILE_PATH), f"Output file was not created at {OUTPUT_FILE_PATH}"
assert f"Analysis complete! Output written to: {OUTPUT_FILE_NAME}" in output_lines
assert os.path.exists(OUTPUT_FILE_NAME), f"Output file was not created at {OUTPUT_FILE_NAME}"
os.remove(OUTPUT_FILE_PATH)
os.remove(OUTPUT_FILE_NAME)
def test_cli_with_options():
@ -25,7 +25,7 @@ def test_cli_with_options():
[
"./",
"--output",
str(OUTPUT_FILE_PATH),
str(OUTPUT_FILE_NAME),
"--max-size",
str(MAX_FILE_SIZE),
"--exclude-pattern",
@ -35,7 +35,7 @@ def test_cli_with_options():
],
)
output_lines = result.output.strip().split("\n")
assert f"Analysis complete! Output written to: {OUTPUT_FILE_PATH}" in output_lines
assert os.path.exists(OUTPUT_FILE_PATH), f"Output file was not created at {OUTPUT_FILE_PATH}"
assert f"Analysis complete! Output written to: {OUTPUT_FILE_NAME}" in output_lines
assert os.path.exists(OUTPUT_FILE_NAME), f"Output file was not created at {OUTPUT_FILE_NAME}"
os.remove(OUTPUT_FILE_PATH)
os.remove(OUTPUT_FILE_NAME)

View file

@ -46,8 +46,8 @@ def cleanup_temp_directories():
if temp_dir.exists():
try:
shutil.rmtree(temp_dir)
except PermissionError as e:
print(f"Error cleaning up {temp_dir}: {e}")
except PermissionError as exc:
print(f"Error cleaning up {temp_dir}: {exc}")
@pytest.fixture(scope="module", autouse=True)

46
tests/test_ingestion.py Normal file
View file

@ -0,0 +1,46 @@
"""
Tests for the `query_ingestion` module.
These tests validate directory scanning, file content extraction, notebook handling, and the overall ingestion logic,
including filtering patterns and subpaths.
"""
from pathlib import Path
from gitingest.ingestion import ingest_query
from gitingest.query_parsing import ParsedQuery
def test_run_ingest_query(temp_directory: Path, sample_query: ParsedQuery) -> None:
    """
    Test `ingest_query` to ensure it processes the directory and returns expected results.

    Given a directory with .txt and .py files:
    When `ingest_query` is invoked,
    Then it should produce a summary string listing the files analyzed and a combined content string.
    """
    # Point the query at the fixture directory and ingest from the repo root
    # with no type filtering.
    sample_query.local_path = temp_directory
    sample_query.subpath = "/"
    sample_query.type = None

    summary, _, content = ingest_query(sample_query)

    # Summary header should name the fixture repo and count all 8 fixture files.
    assert "Repository: test_user/test_repo" in summary
    assert "Files analyzed: 8" in summary

    # Check presence of key files in the content
    assert "src/subfile1.txt" in content
    assert "src/subfile2.py" in content
    assert "src/subdir/file_subdir.txt" in content
    assert "src/subdir/file_subdir.py" in content
    assert "file1.txt" in content
    assert "file2.py" in content
    assert "dir1/file_dir1.txt" in content
    assert "dir2/file_dir2.txt" in content
# TODO: Additional tests:
# - Multiple include patterns, e.g. ["*.txt", "*.py"] or ["/src/*", "*.txt"].
# - Edge cases with weird file names or deep subdirectory structures.
# TODO: def test_include_txt_pattern
# TODO: def test_include_nonexistent_extension

View file

@ -8,7 +8,7 @@ empty cells, outputs, etc.) are handled appropriately.
import pytest
from gitingest.notebook_utils import process_notebook
from gitingest.utils.notebook_utils import process_notebook
from tests.conftest import WriteNotebookFunc

View file

@ -1,209 +0,0 @@
"""
Tests for the `query_ingestion` module.
These tests validate directory scanning, file content extraction, notebook handling, and the overall ingestion logic,
including filtering patterns and subpaths.
"""
from pathlib import Path
from unittest.mock import patch
import pytest
from gitingest.query_ingestion import _extract_files_content, _read_file_content, _scan_directory, run_ingest_query
from gitingest.query_parser import ParsedQuery
def test_scan_directory(temp_directory: Path, sample_query: ParsedQuery) -> None:
"""
Test `_scan_directory` with default settings.
Given a populated test directory:
When `_scan_directory` is called,
Then it should return a structured node containing the correct directories and file counts.
"""
sample_query.local_path = temp_directory
result = _scan_directory(temp_directory, query=sample_query)
assert result is not None, "Expected a valid directory node structure"
assert result["type"] == "directory"
assert result["file_count"] == 8, "Should count all .txt and .py files"
assert result["dir_count"] == 4, "Should include src, src/subdir, dir1, dir2"
assert len(result["children"]) == 5, "Should contain file1.txt, file2.py, src, dir1, dir2"
def test_extract_files_content(temp_directory: Path, sample_query: ParsedQuery) -> None:
"""
Test `_extract_files_content` to ensure it gathers contents from scanned nodes.
Given a populated test directory:
When `_extract_files_content` is called with a valid scan result,
Then it should return a list of file info containing the correct filenames and paths.
"""
sample_query.local_path = temp_directory
nodes = _scan_directory(temp_directory, query=sample_query)
assert nodes is not None, "Expected a valid scan result"
files = _extract_files_content(query=sample_query, node=nodes)
assert len(files) == 8, "Should extract all .txt and .py files"
paths = [f["path"] for f in files]
# Verify presence of key files
assert any("file1.txt" in p for p in paths)
assert any("subfile1.txt" in p for p in paths)
assert any("file2.py" in p for p in paths)
assert any("subfile2.py" in p for p in paths)
assert any("file_subdir.txt" in p for p in paths)
assert any("file_dir1.txt" in p for p in paths)
assert any("file_dir2.txt" in p for p in paths)
def test_read_file_content_with_notebook(tmp_path: Path) -> None:
"""
Test `_read_file_content` with a notebook file.
Given a minimal .ipynb file:
When `_read_file_content` is called,
Then `process_notebook` should be invoked to handle notebook-specific content.
"""
notebook_path = tmp_path / "dummy_notebook.ipynb"
notebook_path.write_text("{}", encoding="utf-8") # minimal JSON
with patch("gitingest.query_ingestion.process_notebook") as mock_process:
_read_file_content(notebook_path)
mock_process.assert_called_once_with(notebook_path)
def test_read_file_content_with_non_notebook(tmp_path: Path):
"""
Test `_read_file_content` with a non-notebook file.
Given a standard .py file:
When `_read_file_content` is called,
Then `process_notebook` should not be triggered.
"""
py_file_path = tmp_path / "dummy_file.py"
py_file_path.write_text("print('Hello')", encoding="utf-8")
with patch("gitingest.query_ingestion.process_notebook") as mock_process:
_read_file_content(py_file_path)
mock_process.assert_not_called()
def test_include_txt_pattern(temp_directory: Path, sample_query: ParsedQuery) -> None:
"""
Test including only .txt files using a pattern like `*.txt`.
Given a directory with mixed .txt and .py files:
When `include_patterns` is set to `*.txt`,
Then `_scan_directory` should include only .txt files, excluding .py files.
"""
sample_query.local_path = temp_directory
sample_query.include_patterns = {"*.txt"}
result = _scan_directory(temp_directory, query=sample_query)
assert result is not None, "Expected a valid directory node structure"
files = _extract_files_content(query=sample_query, node=result)
file_paths = [f["path"] for f in files]
assert len(files) == 5, "Should find exactly 5 .txt files"
assert all(path.endswith(".txt") for path in file_paths), "Should only include .txt files"
expected_files = ["file1.txt", "subfile1.txt", "file_subdir.txt", "file_dir1.txt", "file_dir2.txt"]
for expected_file in expected_files:
assert any(expected_file in path for path in file_paths), f"Missing expected file: {expected_file}"
assert not any(path.endswith(".py") for path in file_paths), "No .py files should be included"
def test_include_nonexistent_extension(temp_directory: Path, sample_query: ParsedQuery) -> None:
"""
Test including a nonexistent extension (e.g., `*.query`).
Given a directory with no files matching `*.query`:
When `_scan_directory` is called with that pattern,
Then no files should be returned in the result.
"""
sample_query.local_path = temp_directory
sample_query.include_patterns = {"*.query"} # Nonexistent extension
result = _scan_directory(temp_directory, query=sample_query)
assert result is not None, "Expected a valid directory node structure"
files = _extract_files_content(query=sample_query, node=result)
assert len(files) == 0, "Should not find any files matching *.query"
assert result["type"] == "directory"
assert result["file_count"] == 0, "No files counted with this pattern"
assert result["dir_count"] == 0
assert len(result["children"]) == 0
@pytest.mark.parametrize("include_pattern", ["src/*", "src/**", "src*"])
def test_include_src_patterns(temp_directory: Path, sample_query: ParsedQuery, include_pattern: str) -> None:
"""
Test including files under the `src` directory with various patterns.
Given a directory containing `src` with subfiles:
When `include_patterns` is set to `src/*`, `src/**`, or `src*`,
Then `_scan_directory` should include the correct files under `src`.
Note: Windows is not supported; paths are converted to Unix-style for validation.
"""
sample_query.local_path = temp_directory
sample_query.include_patterns = {include_pattern}
result = _scan_directory(temp_directory, query=sample_query)
assert result is not None, "Expected a valid directory node structure"
files = _extract_files_content(query=sample_query, node=result)
# Convert Windows paths to Unix-style
file_paths = {f["path"].replace("\\", "/") for f in files}
expected_paths = {
"src/subfile1.txt",
"src/subfile2.py",
"src/subdir/file_subdir.txt",
"src/subdir/file_subdir.py",
}
assert file_paths == expected_paths, "Missing or unexpected files in result"
def test_run_ingest_query(temp_directory: Path, sample_query: ParsedQuery) -> None:
"""
Test `run_ingest_query` to ensure it processes the directory and returns expected results.
Given a directory with .txt and .py files:
When `run_ingest_query` is invoked,
Then it should produce a summary string listing the files analyzed and a combined content string.
"""
sample_query.local_path = temp_directory
sample_query.subpath = "/"
sample_query.type = None
summary, _, content = run_ingest_query(sample_query)
assert "Repository: test_user/test_repo" in summary
assert "Files analyzed: 8" in summary
# Check presence of key files in the content
assert "src/subfile1.txt" in content
assert "src/subfile2.py" in content
assert "src/subdir/file_subdir.txt" in content
assert "src/subdir/file_subdir.py" in content
assert "file1.txt" in content
assert "file2.py" in content
assert "dir1/file_dir1.txt" in content
assert "dir2/file_dir2.txt" in content
# TODO: Additional tests:
# - Multiple include patterns, e.g. ["*.txt", "*.py"] or ["/src/*", "*.txt"].
# - Edge cases with weird file names or deep subdirectory structures.

View file

@ -1,5 +1,5 @@
"""
Tests for the `repository_clone` module.
Tests for the `cloning` module.
These tests cover various scenarios for cloning repositories, verifying that the appropriate Git commands are invoked
and handling edge cases such as nonexistent URLs, timeouts, redirects, and specific commits or branches.
@ -12,8 +12,8 @@ from unittest.mock import AsyncMock, patch
import pytest
from gitingest.cloning import CloneConfig, _check_repo_exists, clone_repo
from gitingest.exceptions import AsyncTimeoutError
from gitingest.repository_clone import CloneConfig, _check_repo_exists, clone_repo
@pytest.mark.asyncio
@ -32,8 +32,8 @@ async def test_clone_repo_with_commit() -> None:
branch="main",
)
with patch("gitingest.repository_clone._check_repo_exists", return_value=True) as mock_check:
with patch("gitingest.repository_clone._run_command", new_callable=AsyncMock) as mock_exec:
with patch("gitingest.cloning._check_repo_exists", return_value=True) as mock_check:
with patch("gitingest.cloning._run_command", new_callable=AsyncMock) as mock_exec:
mock_process = AsyncMock()
mock_process.communicate.return_value = (b"output", b"error")
mock_exec.return_value = mock_process
@ -60,8 +60,8 @@ async def test_clone_repo_without_commit() -> None:
branch="main",
)
with patch("gitingest.repository_clone._check_repo_exists", return_value=True) as mock_check:
with patch("gitingest.repository_clone._run_command", new_callable=AsyncMock) as mock_exec:
with patch("gitingest.cloning._check_repo_exists", return_value=True) as mock_check:
with patch("gitingest.cloning._run_command", new_callable=AsyncMock) as mock_exec:
mock_process = AsyncMock()
mock_process.communicate.return_value = (b"output", b"error")
mock_exec.return_value = mock_process
@ -87,7 +87,7 @@ async def test_clone_repo_nonexistent_repository() -> None:
commit=None,
branch="main",
)
with patch("gitingest.repository_clone._check_repo_exists", return_value=False) as mock_check:
with patch("gitingest.cloning._check_repo_exists", return_value=False) as mock_check:
with pytest.raises(ValueError, match="Repository not found"):
await clone_repo(clone_config)
@ -135,14 +135,13 @@ async def test_clone_repo_with_custom_branch() -> None:
Then the repository should be cloned shallowly to that branch.
"""
clone_config = CloneConfig(url="https://github.com/user/repo", local_path="/tmp/repo", branch="feature-branch")
with patch("gitingest.repository_clone._check_repo_exists", return_value=True):
with patch("gitingest.repository_clone._run_command", new_callable=AsyncMock) as mock_exec:
with patch("gitingest.cloning._check_repo_exists", return_value=True):
with patch("gitingest.cloning._run_command", new_callable=AsyncMock) as mock_exec:
await clone_repo(clone_config)
mock_exec.assert_called_once_with(
"git",
"clone",
"--recurse-submodules",
"--single-branch",
"--depth=1",
"--branch",
@ -165,8 +164,8 @@ async def test_git_command_failure() -> None:
url="https://github.com/user/repo",
local_path="/tmp/repo",
)
with patch("gitingest.repository_clone._check_repo_exists", return_value=True):
with patch("gitingest.repository_clone._run_command", side_effect=RuntimeError("Git command failed")):
with patch("gitingest.cloning._check_repo_exists", return_value=True):
with patch("gitingest.cloning._run_command", side_effect=RuntimeError("Git command failed")):
with pytest.raises(RuntimeError, match="Git command failed"):
await clone_repo(clone_config)
@ -185,14 +184,13 @@ async def test_clone_repo_default_shallow_clone() -> None:
local_path="/tmp/repo",
)
with patch("gitingest.repository_clone._check_repo_exists", return_value=True):
with patch("gitingest.repository_clone._run_command", new_callable=AsyncMock) as mock_exec:
with patch("gitingest.cloning._check_repo_exists", return_value=True):
with patch("gitingest.cloning._run_command", new_callable=AsyncMock) as mock_exec:
await clone_repo(clone_config)
mock_exec.assert_called_once_with(
"git",
"clone",
"--recurse-submodules",
"--single-branch",
"--depth=1",
clone_config.url,
@ -214,14 +212,12 @@ async def test_clone_repo_commit_without_branch() -> None:
local_path="/tmp/repo",
commit="a" * 40, # Simulating a valid commit hash
)
with patch("gitingest.repository_clone._check_repo_exists", return_value=True):
with patch("gitingest.repository_clone._run_command", new_callable=AsyncMock) as mock_exec:
with patch("gitingest.cloning._check_repo_exists", return_value=True):
with patch("gitingest.cloning._run_command", new_callable=AsyncMock) as mock_exec:
await clone_repo(clone_config)
assert mock_exec.call_count == 2 # Clone and checkout calls
mock_exec.assert_any_call(
"git", "clone", "--recurse-submodules", "--single-branch", clone_config.url, clone_config.local_path
)
mock_exec.assert_any_call("git", "clone", "--single-branch", clone_config.url, clone_config.local_path)
mock_exec.assert_any_call("git", "-C", clone_config.local_path, "checkout", clone_config.commit)
@ -278,8 +274,8 @@ async def test_clone_repo_with_timeout() -> None:
"""
clone_config = CloneConfig(url="https://github.com/user/repo", local_path="/tmp/repo")
with patch("gitingest.repository_clone._check_repo_exists", return_value=True):
with patch("gitingest.repository_clone._run_command", new_callable=AsyncMock) as mock_exec:
with patch("gitingest.cloning._check_repo_exists", return_value=True):
with patch("gitingest.cloning._run_command", new_callable=AsyncMock) as mock_exec:
mock_exec.side_effect = asyncio.TimeoutError
with pytest.raises(AsyncTimeoutError, match="Operation timed out after"):
await clone_repo(clone_config)
@ -324,14 +320,13 @@ async def test_clone_branch_with_slashes(tmp_path):
local_path = tmp_path / "gitingest"
clone_config = CloneConfig(url=repo_url, local_path=str(local_path), branch=branch_name)
with patch("gitingest.repository_clone._check_repo_exists", return_value=True):
with patch("gitingest.repository_clone._run_command", new_callable=AsyncMock) as mock_exec:
with patch("gitingest.cloning._check_repo_exists", return_value=True):
with patch("gitingest.cloning._run_command", new_callable=AsyncMock) as mock_exec:
await clone_repo(clone_config)
mock_exec.assert_called_once_with(
"git",
"clone",
"--recurse-submodules",
"--single-branch",
"--depth=1",
"--branch",
@ -356,8 +351,8 @@ async def test_clone_repo_creates_parent_directory(tmp_path: Path) -> None:
local_path=str(nested_path),
)
with patch("gitingest.repository_clone._check_repo_exists", return_value=True):
with patch("gitingest.repository_clone._run_command", new_callable=AsyncMock) as mock_exec:
with patch("gitingest.cloning._check_repo_exists", return_value=True):
with patch("gitingest.cloning._run_command", new_callable=AsyncMock) as mock_exec:
await clone_repo(clone_config)
# Verify parent directory was created
@ -367,7 +362,6 @@ async def test_clone_repo_creates_parent_directory(tmp_path: Path) -> None:
mock_exec.assert_called_once_with(
"git",
"clone",
"--recurse-submodules",
"--single-branch",
"--depth=1",
clone_config.url,
@ -386,15 +380,14 @@ async def test_clone_with_specific_subpath() -> None:
"""
clone_config = CloneConfig(url="https://github.com/user/repo", local_path="/tmp/repo", subpath="src/docs")
with patch("gitingest.repository_clone._check_repo_exists", return_value=True):
with patch("gitingest.repository_clone._run_command", new_callable=AsyncMock) as mock_exec:
with patch("gitingest.cloning._check_repo_exists", return_value=True):
with patch("gitingest.cloning._run_command", new_callable=AsyncMock) as mock_exec:
await clone_repo(clone_config)
# Verify the clone command includes sparse checkout flags
mock_exec.assert_any_call(
"git",
"clone",
"--recurse-submodules",
"--single-branch",
"--filter=blob:none",
"--sparse",
@ -426,15 +419,14 @@ async def test_clone_with_commit_and_subpath() -> None:
subpath="src/docs",
)
with patch("gitingest.repository_clone._check_repo_exists", return_value=True):
with patch("gitingest.repository_clone._run_command", new_callable=AsyncMock) as mock_exec:
with patch("gitingest.cloning._check_repo_exists", return_value=True):
with patch("gitingest.cloning._run_command", new_callable=AsyncMock) as mock_exec:
await clone_repo(clone_config)
# Verify the clone command includes sparse checkout flags
mock_exec.assert_any_call(
"git",
"clone",
"--recurse-submodules",
"--single-branch",
"--filter=blob:none",
"--sparse",