supermemory/data-generator/planner.py
Dhravya 771be5cef8 fix: apply review feedback — fix double data/ prefix, semaphore bug, resume bug, consolidate duplicated code
- Fix worker.py writing to data/data/ instead of data/ (critical path bug)
- Fix semaphore recreation on every call due to checking _value instead of capacity
- Fix questions.py resume returning raw string instead of list[dict]
- Fix prompts/file_gen.py reading 'summary' instead of 'brief' from manifest
- Extract shared unwrap_json_list() and truncate_to_tokens() into utils.py
- Remove redundant validation report writes in generate.py
- Remove unused imports and dependencies
- Fix f-string logger calls to use lazy %s formatting
- Move calendar import to top-level in validator.py
- Use write_text() for atomic writes in repair_files()
- Strengthen test_resume_support to assert return type
2026-04-28 23:49:23 +00:00

524 lines
16 KiB
Python

"""Planning module for the eval corpus data generator.
Handles three sequential phases:
Phase 1 — Scenario Brief (SCENARIO.md)
Phase 2 — Fact Registry (facts.json)
Phase 3 — File Manifest (manifest.json)
"""
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import Any
from utils import (
DEFAULT_MODEL,
FAST_MODEL,
count_tokens,
llm_call,
llm_call_json,
read_text,
unwrap_json_list,
write_json,
write_text,
)
from prompts.scenario_brief import (
SCENARIO_BRIEF_SYSTEM,
format_scenario_brief_prompt,
)
from prompts.fact_registry import (
FACT_REGISTRY_SYSTEM,
format_fact_registry_prompt,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Corpora with at most this many files get their manifest from a single LLM
# call; anything larger goes through the outline + per-section path.
LARGE_CORPUS_THRESHOLD = 50
# Approximate number of files per section; interpolated into OUTLINE_PROMPT as
# {chunk_size} when planning a large corpus.
CHUNK_SIZE = 30
# ---------------------------------------------------------------------------
# Manifest prompt templates
# ---------------------------------------------------------------------------
# System prompt passed as `system=` to every manifest-related LLM call in this
# module: the single-shot manifest, the section outline, and each per-section
# manifest. It takes no format placeholders.
MANIFEST_SYSTEM = """\
You are a corpus architect. Given a scenario brief and fact registry you must produce
a JSON manifest describing every file in the corpus. Each entry specifies exactly
what a downstream file-generator worker needs to produce that file.
Rules:
- file_id values are sequential: f001, f002, …
- target_tokens [min, max] must both be in [5000, 10000]
- locked_facts lists reference fact IDs from the registry — be exhaustive
- cross_references must be bidirectional (if f001 references f002, f002 references f001)
- cluster_hint groups related files (e.g. "legal", "medical_records", "travel")
- brief is 2-3 sentences describing the file's content
- authors is a list of person IDs from the fact registry
"""
# User prompt for the single-call manifest path (_generate_small_manifest).
# str.format placeholders: {scenario_brief}, {fact_registry} (JSON text),
# {file_count}. The field list mirrors the required fields checked by
# _validate_manifest_entry.
MANIFEST_PROMPT = """\
## Task
Generate a file manifest (JSON array) for the corpus described below.
## Scenario Brief
{scenario_brief}
## Fact Registry
```json
{fact_registry}
```
## Requirements
Generate exactly {file_count} file entries as a JSON array. Each entry must have:
- **file_id**: sequential ID (f001, f002, …)
- **path**: relative path under data/ (e.g. "data/emails/booking_confirmation.eml")
- **format**: document format (markdown_prose, email_thread, transcript, legal_contract, \
lab_report, slack_export, csv_data, json_structured, etc.)
- **authors**: list of person IDs from the fact registry
- **date**: ISO 8601 date (YYYY-MM-DD)
- **target_tokens**: [min, max] both within [5000, 10000]
- **locked_facts**: list of fact IDs from the registry that MUST appear in this file
- **cross_references**: list of other file_ids this file references or is referenced by
- **cluster_hint**: group name for related files
- **brief**: 2-3 sentence description of contents
- **tone**: formal/casual/clinical/technical/etc.
- **format_notes**: specific formatting requirements
Return ONLY a JSON array — no wrapper object, no markdown fences.
"""
# User prompt for step 1 of the large-corpus path: asks the model for a
# section outline instead of the full manifest.
# str.format placeholders: {file_count}, {scenario_summary}, {chunk_size}.
# The doubled braces ({{ }}) are str.format escapes so the JSON example is
# emitted with literal single braces.
OUTLINE_PROMPT = """\
## Task
You are planning a large corpus of {file_count} files. To manage complexity, first
produce a department/section outline that organizes the files into logical groups.
## Scenario Brief (Summary)
{scenario_summary}
## Requirements
Return a JSON object with this structure:
```json
{{
"sections": [
{{
"name": "Section Name",
"cluster_hint": "section_slug",
"file_count": 15,
"description": "What files in this section cover"
}}
]
}}
```
Rules:
- Total file_count across all sections must equal exactly {file_count}
- Each section should have roughly {chunk_size} files (±10)
- Section names should be descriptive (e.g. "Legal Documents", "Medical Records")
- cluster_hint must be a URL-safe slug
"""
# User prompt for step 2 of the large-corpus path: one call per outline section.
# str.format placeholders: {section_name}, {scenario_brief}, {fact_registry}
# (JSON text), {section_description}, {cluster_hint}, {section_file_count},
# and {start_id} — an int rendered zero-padded to three digits via `:03d`.
SECTION_MANIFEST_PROMPT = """\
## Task
Generate file manifest entries for the "{section_name}" section of the corpus.
## Scenario Brief
{scenario_brief}
## Fact Registry
```json
{fact_registry}
```
## Section Details
- **Section**: {section_name} ({section_description})
- **Cluster Hint**: {cluster_hint}
- **File Count**: {section_file_count}
- **Starting file_id**: f{start_id:03d}
## Requirements
Generate exactly {section_file_count} file entries as a JSON array. Each entry must have:
- **file_id**: sequential starting from f{start_id:03d}
- **path**: relative path under data/ (e.g. "data/{cluster_hint}/filename.ext")
- **format**: document format
- **authors**: list of person IDs from the fact registry
- **date**: ISO 8601 date (YYYY-MM-DD)
- **target_tokens**: [min, max] both within [5000, 10000]
- **locked_facts**: list of fact IDs from the registry that MUST appear in this file
- **cross_references**: list of other file_ids this file references (use IDs from any section)
- **cluster_hint**: "{cluster_hint}"
- **brief**: 2-3 sentence description of contents
- **tone**: formal/casual/clinical/technical/etc.
- **format_notes**: specific formatting requirements
Return ONLY a JSON array — no wrapper object, no markdown fences.
"""
# ---------------------------------------------------------------------------
# Validation helpers
# ---------------------------------------------------------------------------
def _validate_fact_registry(registry: dict[str, Any]) -> dict[str, Any]:
    """Validate the top-level structure of a fact registry.

    Checks that the registry is a JSON object containing the ``people``,
    ``organizations`` and ``dates`` keys, and that every ``people`` entry is
    an object carrying an ``id`` field.

    Args:
        registry: Parsed fact registry (in this module, the LLM's JSON output).

    Returns:
        The same ``registry`` object, unchanged.

    Raises:
        ValueError: If the registry is not a dict, a required key is missing,
            ``people`` is not a list, or a person entry is not a dict with an
            ``id`` key.
    """
    # The registry comes straight from parsed LLM output, so it may not even
    # be an object; fail with the documented exception type, not AttributeError.
    if not isinstance(registry, dict):
        raise ValueError(
            f"Fact registry must be a JSON object, got {type(registry).__name__}"
        )

    required_keys = {"people", "organizations", "dates"}
    missing = required_keys - registry.keys()
    if missing:
        raise ValueError(f"Fact registry missing required keys: {missing}")

    people = registry["people"]
    if not isinstance(people, list):
        raise ValueError(
            f"Fact registry 'people' must be a list, got {type(people).__name__}"
        )

    for person in people:
        # Require a dict explicitly: on a plain string, `"id" not in person`
        # is a substring test, so an entry like "david" would pass silently.
        if not isinstance(person, dict) or "id" not in person:
            raise ValueError(f"Person entry missing 'id': {person}")
    return registry
def _validate_manifest_entry(entry: dict[str, Any], idx: int) -> list[str]:
    """Validate a single manifest entry without raising.

    Args:
        entry: One manifest entry as parsed from the LLM's JSON output.
        idx: Position of the entry in the manifest, used in warning text.

    Returns:
        A list of human-readable warnings; an empty list means the entry
        passed every check.
    """
    # Entries are unvetted model output, so even the container type is checked.
    if not isinstance(entry, dict):
        return [f"Entry {idx} is not a JSON object: {entry!r}"]

    warnings: list[str] = []
    required_fields = {
        "file_id", "path", "format", "authors", "date",
        "target_tokens", "locked_facts", "cross_references",
        "cluster_hint", "brief", "tone", "format_notes",
    }
    missing = required_fields - entry.keys()
    if missing:
        warnings.append(f"Entry {idx} missing fields: {missing}")

    tokens = entry.get("target_tokens")
    if tokens is None:
        # Absence is already covered by the missing-fields warning above.
        return warnings

    # Both bounds must be real numbers before they are compared; otherwise a
    # value like ["5k", "10k"] would raise TypeError instead of warning.
    well_formed = (
        isinstance(tokens, list)
        and len(tokens) == 2
        and all(
            isinstance(bound, (int, float)) and not isinstance(bound, bool)
            for bound in tokens
        )
    )
    if not well_formed:
        warnings.append(f"Entry {idx} target_tokens malformed: {tokens}")
        return warnings

    lo, hi = tokens
    if not (5000 <= lo <= 10000 and 5000 <= hi <= 10000):
        warnings.append(
            f"Entry {idx} target_tokens {tokens} outside [5000, 10000]"
        )
    if lo > hi:
        warnings.append(f"Entry {idx} target_tokens min > max: {tokens}")
    return warnings
def _validate_manifest(manifest: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Run per-entry validation over the whole manifest.

    Every warning produced by ``_validate_manifest_entry`` is logged at
    WARNING level; nothing is raised and the manifest itself is returned
    as-is so the call can be chained.
    """
    problems = [
        message
        for position, item in enumerate(manifest)
        for message in _validate_manifest_entry(item, position)
    ]
    for message in problems:
        logger.warning("Manifest validation: %s", message)
    return manifest
def _renumber_manifest(manifest: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Re-number file_ids sequentially (f001, f002, …) and remap cross_references.

    Entries are modified in place. This runs before manifest validation, so it
    tolerates the malformed shapes validation would later warn about: a
    missing ``file_id``, a null ``cross_references``, or a single reference
    given as a bare string.

    Args:
        manifest: Manifest entries in their final order.

    Returns:
        The same list, with every ``file_id`` rewritten to its 1-based position
        and every ``cross_references`` value normalised to a list whose known
        ids are translated. References to ids not present in the manifest are
        kept unchanged.
    """
    # Pass 1: assign new ids and remember old-id -> new-id.
    id_map: dict[str, str] = {}
    for position, entry in enumerate(manifest, start=1):
        new_id = f"f{position:03d}"
        old_id = entry.get("file_id")
        # Skip missing/empty ids so that "" never becomes a valid ref target.
        if old_id not in (None, ""):
            if old_id in id_map:
                # Last occurrence wins, but references to this id can no longer
                # be resolved unambiguously.
                logger.warning(
                    "Duplicate file_id %s in manifest; cross-references to it "
                    "are ambiguous",
                    old_id,
                )
            id_map[old_id] = new_id
        entry["file_id"] = new_id

    # Pass 2: translate references now that the full mapping is known.
    for entry in manifest:
        refs = entry.get("cross_references")
        if refs is None:
            refs = []
        elif isinstance(refs, str):
            # A bare string would otherwise be iterated character by character.
            refs = [refs]
        entry["cross_references"] = [id_map.get(ref, ref) for ref in refs]
    return manifest
# ---------------------------------------------------------------------------
# Phase 1: Scenario Brief
# ---------------------------------------------------------------------------
async def generate_scenario_brief(
    scenario_block: str,
    file_count: int,
    output_dir: Path,
    model: str = DEFAULT_MODEL,
) -> str:
    """Phase 1: produce SCENARIO.md and return its text.

    If ``output_dir / "SCENARIO.md"`` already exists it is read back and
    returned without calling the LLM (resume support). Otherwise the brief is
    generated from ``scenario_block`` and ``file_count``, written to that path
    and returned.
    """
    scenario_path = output_dir / "SCENARIO.md"

    if not scenario_path.exists():
        logger.info("Phase 1: Generating scenario brief …")
        generated = await llm_call(
            format_scenario_brief_prompt(scenario_block, file_count),
            model=model,
            system=SCENARIO_BRIEF_SYSTEM,
            max_tokens=16384,
        )
        write_text(scenario_path, generated)
        token_total = count_tokens(generated)
        logger.info(
            "Phase 1 complete — SCENARIO.md written (%d tokens)", token_total
        )
        return generated

    # Resume path: reuse the brief produced by an earlier run.
    logger.info("Phase 1 skipped — SCENARIO.md already exists")
    return read_text(scenario_path)
# ---------------------------------------------------------------------------
# Phase 2: Fact Registry
# ---------------------------------------------------------------------------
async def extract_fact_registry(
    scenario_brief: str,
    output_dir: Path,
    model: str = DEFAULT_MODEL,
) -> dict:
    """Phase 2: build facts.json from the scenario brief and return it.

    An existing ``output_dir / "facts.json"`` is parsed and returned as-is
    (resume support, no re-validation). Otherwise the registry is requested
    from the LLM, unwrapped if the model nested it under a single key,
    validated, and written to disk.
    """
    facts_path = output_dir / "facts.json"

    if facts_path.exists():
        logger.info("Phase 2 skipped — facts.json already exists")
        return json.loads(read_text(facts_path))

    logger.info("Phase 2: Extracting fact registry …")
    registry = await llm_call_json(
        format_fact_registry_prompt(scenario_brief),
        model=model,
        system=FACT_REGISTRY_SYSTEM,
        max_tokens=16384,
    )

    # Some responses look like {"fact_registry": {...}}; peel off that single
    # wrapper only when the inner object looks like a registry.
    if isinstance(registry, dict) and len(registry) == 1:
        (only_value,) = registry.values()
        if isinstance(only_value, dict) and (
            "people" in only_value or "organizations" in only_value
        ):
            registry = only_value

    _validate_fact_registry(registry)
    write_json(facts_path, registry)

    # Only list-valued sections are counted as fact entries.
    fact_count = 0
    for section in registry.values():
        if isinstance(section, list):
            fact_count += len(section)
    logger.info("Phase 2 complete — facts.json written (%d fact entries)", fact_count)
    return registry
# ---------------------------------------------------------------------------
# Phase 3: File Manifest
# ---------------------------------------------------------------------------
async def _generate_small_manifest(
    scenario_brief: str,
    fact_registry: dict,
    file_count: int,
    model: str,
) -> list[dict]:
    """Build the whole manifest with one LLM request (≤50 files).

    The response is passed through ``unwrap_json_list`` so that a result
    wrapped under "files", "manifest" or "entries" is handled too.
    """
    registry_json = json.dumps(fact_registry, indent=2)
    request = MANIFEST_PROMPT.format(
        file_count=file_count,
        fact_registry=registry_json,
        scenario_brief=scenario_brief,
    )
    response = await llm_call_json(
        request,
        model=model,
        system=MANIFEST_SYSTEM,
        max_tokens=16384,
    )
    wrapper_keys = ("files", "manifest", "entries")
    return unwrap_json_list(response, expected_keys=wrapper_keys)
def _rebalance_section_counts(sections: list[dict], file_count: int) -> list[dict]:
    """Make per-section ``file_count`` values non-negative ints summing to ``file_count``.

    Missing, null, non-numeric or negative counts are treated as 0. A shortfall
    is added to the last section. A surplus is removed starting from the last
    section and moving backwards, never taking a section below zero. Sections
    left with no files are dropped from the returned list.

    ``sections`` must be a non-empty list of dicts; they are updated in place.
    """
    for section in sections:
        try:
            count = int(section.get("file_count") or 0)
        except (TypeError, ValueError):
            count = 0
        section["file_count"] = max(count, 0)

    diff = file_count - sum(section["file_count"] for section in sections)
    if diff > 0:
        sections[-1]["file_count"] += diff
    elif diff < 0:
        surplus = -diff
        for section in reversed(sections):
            removed = min(section["file_count"], surplus)
            section["file_count"] -= removed
            surplus -= removed
            if surplus == 0:
                break
    if diff:
        logger.warning(
            "Adjusted section file_counts by %d to match total %d",
            diff,
            file_count,
        )
    return [section for section in sections if section["file_count"] > 0]


async def _generate_large_manifest(
    scenario_brief: str,
    fact_registry: dict,
    file_count: int,
    model: str,
) -> list[dict]:
    """Generate manifest in chunks for large corpora (>50 files).

    Step 1 asks the LLM for a section outline; step 2 asks for the manifest
    entries of each section in turn, telling each call which file_id it should
    start from.

    Args:
        scenario_brief: Full SCENARIO.md text.
        fact_registry: Parsed fact registry, embedded as JSON in each prompt.
        file_count: Total number of files the corpus should contain.
        model: Model identifier forwarded to ``llm_call_json``.

    Returns:
        The concatenated entries of all sections, in outline order. Entry ids
        are whatever the model produced; the caller renumbers them.

    Raises:
        ValueError: If the outline contains no usable sections.
    """
    # The outline prompt only needs the gist, so very long briefs are cut by
    # character count (the per-section prompts still get the full brief).
    brief_tokens = count_tokens(scenario_brief)
    if brief_tokens > 6000:
        scenario_summary = scenario_brief[:12000] + "\n\n[… truncated for outline …]"
    else:
        scenario_summary = scenario_brief

    # Step 1: Generate section outline
    logger.info("Phase 3a: Generating section outline for %d files …", file_count)
    outline_prompt = OUTLINE_PROMPT.format(
        file_count=file_count,
        scenario_summary=scenario_summary,
        chunk_size=CHUNK_SIZE,
    )
    outline = await llm_call_json(
        outline_prompt,
        model=model,
        system=MANIFEST_SYSTEM,
        max_tokens=4096,
    )

    # Accept both {"sections": [...]} and a bare [...] and ignore anything
    # that is not a section object.
    raw_sections = outline.get("sections") if isinstance(outline, dict) else outline
    sections = (
        [section for section in raw_sections if isinstance(section, dict)]
        if isinstance(raw_sections, list)
        else []
    )
    if not sections:
        raise ValueError("Outline generation returned no sections")

    # The model's counts rarely add up exactly; reconcile them without ever
    # producing an empty or negative section.
    sections = _rebalance_section_counts(sections, file_count)
    for position, section in enumerate(sections, start=1):
        section["name"] = str(section.get("name") or f"Section {position}")

    logger.info(
        "Outline has %d sections: %s",
        len(sections),
        ", ".join(f'{s["name"]}({s["file_count"]})' for s in sections),
    )

    # Step 2: Generate manifest for each section
    registry_json = json.dumps(fact_registry, indent=2)  # identical for every section
    all_entries: list[dict] = []
    current_start_id = 1
    for section in sections:
        section_name = section["name"]
        section_file_count = section["file_count"]
        cluster_hint = (
            section.get("cluster_hint") or section_name.lower().replace(" ", "_")
        )
        section_description = section.get("description") or ""
        logger.info(
            "Phase 3b: Generating %d entries for section '%s' (starting f%03d) …",
            section_file_count,
            section_name,
            current_start_id,
        )
        section_prompt = SECTION_MANIFEST_PROMPT.format(
            section_name=section_name,
            scenario_brief=scenario_brief,
            fact_registry=registry_json,
            section_description=section_description,
            cluster_hint=cluster_hint,
            section_file_count=section_file_count,
            start_id=current_start_id,
        )
        result = await llm_call_json(
            section_prompt,
            model=model,
            system=MANIFEST_SYSTEM,
            max_tokens=16384,
        )
        entries = unwrap_json_list(result, expected_keys=("files", "manifest", "entries"))
        all_entries.extend(entries)
        # Advance by the planned count, not len(entries), so the id ranges
        # announced to later sections follow the outline.
        current_start_id += section_file_count
    return all_entries
async def generate_manifest(
    scenario_brief: str,
    fact_registry: dict,
    file_count: int,
    output_dir: Path,
    model: str = DEFAULT_MODEL,
) -> list[dict]:
    """Phase 3: Generate manifest.json. Returns list of file entries.

    An existing ``output_dir / "manifest.json"`` is loaded and returned
    (resume support). Otherwise the manifest is generated with a single call
    for corpora of up to ``LARGE_CORPUS_THRESHOLD`` files and section by
    section above that, then renumbered, validated (warnings only) and written.

    Args:
        scenario_brief: Full SCENARIO.md text.
        fact_registry: Parsed fact registry.
        file_count: Number of files requested for the corpus.
        output_dir: Directory that holds manifest.json.
        model: Model identifier forwarded to the generation helpers.

    Raises:
        ValueError: If files were requested but generation produced no entries.
            Nothing is written in that case, so a later run retries instead of
            resuming from an empty manifest.
    """
    output_path = output_dir / "manifest.json"

    # Resume support: skip if already exists
    if output_path.exists():
        logger.info("Phase 3 skipped — manifest.json already exists")
        data = json.loads(read_text(output_path))
        # Also accept a manifest stored under a top-level "files" key.
        if isinstance(data, dict) and "files" in data:
            return data["files"]
        return data

    logger.info("Phase 3: Generating manifest for %d files …", file_count)
    if file_count <= LARGE_CORPUS_THRESHOLD:
        manifest = await _generate_small_manifest(
            scenario_brief, fact_registry, file_count, model
        )
    else:
        manifest = await _generate_large_manifest(
            scenario_brief, fact_registry, file_count, model
        )

    # An empty manifest on disk would make every later run "resume" from it.
    if not manifest and file_count > 0:
        raise ValueError("Manifest generation returned no entries")
    if len(manifest) != file_count:
        logger.warning(
            "Manifest has %d entries but %d were requested",
            len(manifest),
            file_count,
        )

    # Re-number sequentially and fix cross-references
    manifest = _renumber_manifest(manifest)
    _validate_manifest(manifest)
    write_json(output_path, manifest)
    logger.info("Phase 3 complete — manifest.json written (%d entries)", len(manifest))
    return manifest
# ---------------------------------------------------------------------------
# Orchestrator
# ---------------------------------------------------------------------------
async def run_planning(
    scenario_block: str,
    file_count: int,
    output_dir: Path,
    model: str = DEFAULT_MODEL,
) -> tuple[str, dict, list[dict]]:
    """Execute planning phases 1-3 in order, each feeding the next.

    ``output_dir`` is created first if needed (including parents).

    Returns:
        A ``(brief, facts, manifest)`` tuple.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    scenario_text = await generate_scenario_brief(
        scenario_block, file_count, output_dir, model
    )
    registry = await extract_fact_registry(scenario_text, output_dir, model)
    file_entries = await generate_manifest(
        scenario_text, registry, file_count, output_dir, model
    )
    return scenario_text, registry, file_entries