fix: prevent SurrealDB injection via order_by and unparameterized queries

- Add allowlist validation for order_by param in notebooks endpoint
- Parameterize session_id query in source_chat router
- Add regex validation in base.py get_all() order_by parameter
- Convert async_migrate bump/lower_version to parameterized queries
This commit is contained in:
Luis Novo 2026-04-07 07:51:25 -03:00
parent 6274358b21
commit e5b253b11d
4 changed files with 71 additions and 5 deletions

View file

@ -24,13 +24,38 @@ async def get_notebooks(
):
"""Get all notebooks with optional filtering and ordering."""
try:
# Validate order_by against allowlist to prevent SurrealQL injection
allowed_fields = {"name", "created", "updated"}
allowed_directions = {"asc", "desc"}
parts = order_by.strip().lower().split()
if len(parts) == 1:
if parts[0] not in allowed_fields:
raise HTTPException(
status_code=400,
detail=f"Invalid order_by field: '{order_by}'. Allowed fields: {', '.join(sorted(allowed_fields))}",
)
validated_order_by = parts[0]
elif len(parts) == 2:
if parts[0] not in allowed_fields or parts[1] not in allowed_directions:
raise HTTPException(
status_code=400,
detail=f"Invalid order_by: '{order_by}'. Allowed fields: {', '.join(sorted(allowed_fields))}. Allowed directions: asc, desc",
)
validated_order_by = f"{parts[0]} {parts[1]}"
else:
raise HTTPException(
status_code=400,
detail=f"Invalid order_by format: '{order_by}'. Expected 'field' or 'field direction'",
)
# Build the query with counts
query = f"""
SELECT *,
count(<-reference.in) as source_count,
count(<-artifact.in) as note_count
FROM notebook
ORDER BY {order_by}
ORDER BY {validated_order_by}
"""
result = await repo_query(query)
@ -52,6 +77,8 @@ async def get_notebooks(
)
for nb in result
]
except HTTPException:
raise
except Exception as e:
logger.error(f"Error fetching notebooks: {str(e)}")
raise HTTPException(

View file

@ -155,7 +155,9 @@ async def get_source_chat_sessions(source_id: str = Path(..., description="Sourc
if session_id_raw:
session_id = str(session_id_raw)
session_result = await repo_query(f"SELECT * FROM {session_id_raw}")
session_result = await repo_query(
"SELECT * FROM $id", {"id": ensure_record_id(session_id)}
)
if session_result and len(session_result) > 0:
session_data = session_result[0]

View file

@ -223,7 +223,8 @@ async def bump_version() -> None:
new_version = current_version + 1
await repo_query(
f"CREATE _sbl_migrations:{new_version} SET version = {new_version}, applied_at = time::now();",
"CREATE type::thing('_sbl_migrations', $version) SET version = $version, applied_at = time::now();",
{"version": new_version},
)
@ -231,4 +232,7 @@ async def lower_version() -> None:
"""Lower the version by removing the latest entry from migrations table."""
current_version = await get_latest_version()
if current_version > 0:
await repo_query(f"DELETE _sbl_migrations:{current_version};")
await repo_query(
"DELETE type::thing('_sbl_migrations', $version);",
{"version": current_version},
)

View file

@ -48,7 +48,40 @@ class ObjectModel(BaseModel):
"get_all() must be called from a specific model class"
)
if order_by:
query = f"SELECT * FROM {table_name} ORDER BY {order_by}"
# Validate order_by to prevent SurrealQL injection
# Supports: "field", "field direction", "field1 direction, field2 direction"
import re
allowed_field_pattern = re.compile(r"^[a-z_][a-z0-9_]*$")
allowed_directions = {"asc", "desc"}
clauses = [c.strip() for c in order_by.split(",")]
validated_clauses = []
for clause in clauses:
parts = clause.strip().split()
if len(parts) == 1:
if not allowed_field_pattern.match(parts[0].lower()):
raise InvalidInputError(
f"Invalid order_by field: '{parts[0]}'"
)
validated_clauses.append(parts[0].lower())
elif len(parts) == 2:
if not allowed_field_pattern.match(
parts[0].lower()
) or parts[1].lower() not in allowed_directions:
raise InvalidInputError(
f"Invalid order_by clause: '{clause.strip()}'"
)
validated_clauses.append(
f"{parts[0].lower()} {parts[1].lower()}"
)
else:
raise InvalidInputError(
f"Invalid order_by clause: '{clause.strip()}'"
)
validated_order_by = ", ".join(validated_clauses)
query = f"SELECT * FROM {table_name} ORDER BY {validated_order_by}"
else:
query = f"SELECT * FROM {table_name}"