mirror of
https://github.com/lfnovo/open-notebook.git
synced 2026-04-28 03:19:59 +00:00
fix: prevent SurrealDB injection via order_by and unparameterized queries
- Add allowlist validation for order_by param in notebooks endpoint - Parameterize session_id query in source_chat router - Add regex validation in base.py get_all() order_by parameter - Convert async_migrate bump/lower_version to parameterized queries
This commit is contained in:
parent
6274358b21
commit
e5b253b11d
4 changed files with 71 additions and 5 deletions
|
|
@ -24,13 +24,38 @@ async def get_notebooks(
|
|||
):
|
||||
"""Get all notebooks with optional filtering and ordering."""
|
||||
try:
|
||||
# Validate order_by against allowlist to prevent SurrealQL injection
|
||||
allowed_fields = {"name", "created", "updated"}
|
||||
allowed_directions = {"asc", "desc"}
|
||||
|
||||
parts = order_by.strip().lower().split()
|
||||
if len(parts) == 1:
|
||||
if parts[0] not in allowed_fields:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid order_by field: '{order_by}'. Allowed fields: {', '.join(sorted(allowed_fields))}",
|
||||
)
|
||||
validated_order_by = parts[0]
|
||||
elif len(parts) == 2:
|
||||
if parts[0] not in allowed_fields or parts[1] not in allowed_directions:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid order_by: '{order_by}'. Allowed fields: {', '.join(sorted(allowed_fields))}. Allowed directions: asc, desc",
|
||||
)
|
||||
validated_order_by = f"{parts[0]} {parts[1]}"
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid order_by format: '{order_by}'. Expected 'field' or 'field direction'",
|
||||
)
|
||||
|
||||
# Build the query with counts
|
||||
query = f"""
|
||||
SELECT *,
|
||||
count(<-reference.in) as source_count,
|
||||
count(<-artifact.in) as note_count
|
||||
FROM notebook
|
||||
ORDER BY {order_by}
|
||||
ORDER BY {validated_order_by}
|
||||
"""
|
||||
|
||||
result = await repo_query(query)
|
||||
|
|
@ -52,6 +77,8 @@ async def get_notebooks(
|
|||
)
|
||||
for nb in result
|
||||
]
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching notebooks: {str(e)}")
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@ -155,7 +155,9 @@ async def get_source_chat_sessions(source_id: str = Path(..., description="Sourc
|
|||
if session_id_raw:
|
||||
session_id = str(session_id_raw)
|
||||
|
||||
session_result = await repo_query(f"SELECT * FROM {session_id_raw}")
|
||||
session_result = await repo_query(
|
||||
"SELECT * FROM $id", {"id": ensure_record_id(session_id)}
|
||||
)
|
||||
if session_result and len(session_result) > 0:
|
||||
session_data = session_result[0]
|
||||
|
||||
|
|
|
|||
|
|
@ -223,7 +223,8 @@ async def bump_version() -> None:
|
|||
new_version = current_version + 1
|
||||
|
||||
await repo_query(
|
||||
f"CREATE _sbl_migrations:{new_version} SET version = {new_version}, applied_at = time::now();",
|
||||
"CREATE type::thing('_sbl_migrations', $version) SET version = $version, applied_at = time::now();",
|
||||
{"version": new_version},
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -231,4 +232,7 @@ async def lower_version() -> None:
|
|||
"""Lower the version by removing the latest entry from migrations table."""
|
||||
current_version = await get_latest_version()
|
||||
if current_version > 0:
|
||||
await repo_query(f"DELETE _sbl_migrations:{current_version};")
|
||||
await repo_query(
|
||||
"DELETE type::thing('_sbl_migrations', $version);",
|
||||
{"version": current_version},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -48,7 +48,40 @@ class ObjectModel(BaseModel):
|
|||
"get_all() must be called from a specific model class"
|
||||
)
|
||||
if order_by:
|
||||
query = f"SELECT * FROM {table_name} ORDER BY {order_by}"
|
||||
# Validate order_by to prevent SurrealQL injection
|
||||
# Supports: "field", "field direction", "field1 direction, field2 direction"
|
||||
import re
|
||||
|
||||
allowed_field_pattern = re.compile(r"^[a-z_][a-z0-9_]*$")
|
||||
allowed_directions = {"asc", "desc"}
|
||||
|
||||
clauses = [c.strip() for c in order_by.split(",")]
|
||||
validated_clauses = []
|
||||
for clause in clauses:
|
||||
parts = clause.strip().split()
|
||||
if len(parts) == 1:
|
||||
if not allowed_field_pattern.match(parts[0].lower()):
|
||||
raise InvalidInputError(
|
||||
f"Invalid order_by field: '{parts[0]}'"
|
||||
)
|
||||
validated_clauses.append(parts[0].lower())
|
||||
elif len(parts) == 2:
|
||||
if not allowed_field_pattern.match(
|
||||
parts[0].lower()
|
||||
) or parts[1].lower() not in allowed_directions:
|
||||
raise InvalidInputError(
|
||||
f"Invalid order_by clause: '{clause.strip()}'"
|
||||
)
|
||||
validated_clauses.append(
|
||||
f"{parts[0].lower()} {parts[1].lower()}"
|
||||
)
|
||||
else:
|
||||
raise InvalidInputError(
|
||||
f"Invalid order_by clause: '{clause.strip()}'"
|
||||
)
|
||||
|
||||
validated_order_by = ", ".join(validated_clauses)
|
||||
query = f"SELECT * FROM {table_name} ORDER BY {validated_order_by}"
|
||||
else:
|
||||
query = f"SELECT * FROM {table_name}"
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue