supermemory/data-generator/worker.py
Dhravya 771be5cef8 fix: apply review feedback — fix double data/ prefix, semaphore bug, resume bug, consolidate duplicated code
- Fix worker.py writing to data/data/ instead of data/ (critical path bug)
- Fix semaphore recreation on every call due to checking _value instead of capacity
- Fix questions.py resume returning raw string instead of list[dict]
- Fix prompts/file_gen.py reading 'summary' instead of 'brief' from manifest
- Extract shared unwrap_json_list() and truncate_to_tokens() into utils.py
- Remove redundant validation report writes in generate.py
- Remove unused imports and dependencies
- Fix f-string logger calls to use lazy %s formatting
- Move calendar import to top-level in validator.py
- Use write_text() for atomic writes in repair_files()
- Strengthen test_resume_support to assert return type
2026-04-28 23:49:23 +00:00

525 lines
18 KiB
Python

"""Phase 5: Parallel file generation workers.
Takes clusters of file entries, a fact registry shard, and optionally
already-generated files for cross-reference context. Generates each file
sequentially within a cluster, passing previously generated files as context.
Clusters at the same topological level run in parallel.
"""
from __future__ import annotations
import asyncio
import logging
import time
from pathlib import Path
from typing import Any
from prompts.file_gen import format_file_gen_prompt, format_retry_prompt
from utils import (
DEFAULT_MODEL,
GenerationLog,
count_tokens,
llm_call,
read_text,
truncate_to_tokens,
write_text,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
MAX_CONTEXT_TOKENS_PER_FILE = 3000
MAX_TOTAL_CONTEXT_TOKENS = 15000
MAX_RETRIES = 2
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _get_cluster_file_ids(cluster: Any) -> list[str]:
"""Extract file IDs from a cluster object's ``file_entries`` list."""
if hasattr(cluster, "file_entries"):
return [e.get("file_id", "") for e in cluster.file_entries if e.get("file_id")]
raise TypeError(
f"Cluster object has no 'file_entries' attribute: {type(cluster)}"
)
def _build_context_files(
file_entry: dict,
generated: dict[str, str],
manifest_entries: dict[str, dict],
) -> dict[str, str]:
"""Build the context dict for cross-referenced files.
Priority:
- If the referenced file has been generated: include its content (up to
``MAX_CONTEXT_TOKENS_PER_FILE`` tokens).
- If not yet generated: the brief from the manifest is used later by the
prompt builder (via ``manifest_entries``), so we don't duplicate it here.
If total context would exceed ``MAX_TOTAL_CONTEXT_TOKENS``, we keep the
files with the most cross-references first and drop the rest.
"""
cross_refs: list[str] = file_entry.get("cross_references", [])
if not cross_refs:
return {}
# Collect available generated content for referenced files
candidates: list[tuple[str, str]] = []
for ref_id in cross_refs:
content = generated.get(ref_id)
if content is not None:
truncated = truncate_to_tokens(content, MAX_CONTEXT_TOKENS_PER_FILE)
candidates.append((ref_id, truncated))
# Sort by number of cross-references each candidate has (most connected first)
def _xref_count(file_id: str) -> int:
entry = manifest_entries.get(file_id, {})
return len(entry.get("cross_references", []))
candidates.sort(key=lambda pair: _xref_count(pair[0]), reverse=True)
# Enforce total context budget
context: dict[str, str] = {}
total_tokens = 0
for fid, content in candidates:
tok = count_tokens(content)
if total_tokens + tok > MAX_TOTAL_CONTEXT_TOKENS:
# Try to fit a smaller portion
remaining = MAX_TOTAL_CONTEXT_TOKENS - total_tokens
if remaining > 200:
content = truncate_to_tokens(content, remaining)
context[fid] = content
break
context[fid] = content
total_tokens += tok
return context
def _validate_content(
content: str,
file_entry: dict,
fact_shard: dict,
) -> list[str]:
"""Validate generated content. Returns a list of issue descriptions (empty = valid)."""
issues: list[str] = []
token_count = count_tokens(content)
# --- Token range check ---
target_tokens = file_entry.get("target_tokens", [5000, 10000])
target_min = target_tokens[0] if isinstance(target_tokens, list) else 5000
target_max = target_tokens[1] if isinstance(target_tokens, list) else 10000
if token_count < target_min:
issues.append(
f"Too short: {token_count:,} tokens (minimum {target_min:,}). "
f"Add more realistic content, filler, and noise."
)
elif token_count > target_max * 1.3:
# Allow 30% overshoot before flagging — slight overshoot is better than
# being too short.
issues.append(
f"Too long: {token_count:,} tokens (maximum ~{target_max:,}). "
f"Trim some filler while keeping all locked facts."
)
# --- Locked facts spot-check ---
locked_ids = set(file_entry.get("locked_facts", []))
if locked_ids:
content_lower = content.lower()
missing_facts: list[str] = []
for category in ("financial", "dates", "references", "locations", "domain_facts"):
for fact in fact_shard.get(category, []):
if fact.get("id") not in locked_ids:
continue
# Determine key values to check in the content
key_values = _extract_key_values(fact, category)
found_any = any(
kv.lower() in content_lower for kv in key_values if kv
)
if not found_any and key_values:
missing_facts.append(
f"{fact['id']} (expected one of: {key_values})"
)
if missing_facts:
issues.append(
"Missing locked facts — the following facts were not found "
"in the generated content:\n "
+ "\n ".join(missing_facts)
)
return issues
def _extract_key_values(fact: dict, category: str) -> list[str]:
"""Extract the key string values from a fact that should appear in the document."""
values: list[str] = []
if category == "financial":
val = fact.get("value", "")
if val:
values.append(val)
elif category == "dates":
date_val = fact.get("date", "")
if date_val:
values.append(date_val)
time_val = fact.get("time", "")
if time_val:
values.append(time_val)
elif category == "references":
val = fact.get("value", "")
if val:
values.append(val)
elif category == "locations":
name = fact.get("name", "")
if name:
values.append(name)
addr = fact.get("address", "")
if addr:
values.append(addr)
elif category == "domain_facts":
fact_text = fact.get("fact", "")
if fact_text:
# For domain facts, check for the first significant clause
# (whole fact string may be too long to match literally)
values.append(fact_text)
return values
# ---------------------------------------------------------------------------
# Core generation
# ---------------------------------------------------------------------------
async def generate_file(
file_entry: dict,
fact_shard: dict,
context_files: dict[str, str],
output_dir: Path,
model: str = DEFAULT_MODEL,
gen_log: GenerationLog | None = None,
manifest_entries: dict[str, dict] | None = None,
) -> str:
"""Generate a single file. Returns the generated content.
Args:
file_entry: manifest entry for this file.
fact_shard: relevant portion of fact registry.
context_files: already-generated files this file cross-references
(file_id -> content).
output_dir: base output directory. File is written to
``output_dir / <path>`` (manifest paths already include ``data/``).
model: LLM model to use.
gen_log: optional generation log for tracking.
manifest_entries: full file_id -> manifest entry map (used for
cross-reference briefs of not-yet-generated files).
"""
file_id: str = file_entry.get("file_id", "unknown")
file_path_rel: str = file_entry.get("path", f"data/{file_id}.md")
dest = output_dir / file_path_rel
# --- Resume support ---
if gen_log and gen_log.is_done(file_id):
logger.info("Skipping %s — already done (gen_log)", file_id)
if dest.exists():
return read_text(dest)
# Log says done but file missing — regenerate
logger.warning("%s marked done but file missing, regenerating", file_id)
if dest.exists() and gen_log is None:
logger.info("Skipping %s — file exists on disk", file_id)
return read_text(dest)
# --- Build prompt ---
system_prompt, user_prompt = format_file_gen_prompt(
file_entry=file_entry,
fact_shard=fact_shard,
context_files=context_files,
manifest_entries=manifest_entries or {},
)
# --- Generate with retries ---
content: str = ""
last_issues: list[str] = []
retries_used = 0
t0 = time.monotonic()
for attempt in range(1 + MAX_RETRIES):
try:
if attempt == 0:
content = await llm_call(
user_prompt,
system=system_prompt,
model=model,
max_tokens=16384,
)
else:
# Retry with feedback
retry_prompt = format_retry_prompt(
issues=last_issues,
previous_content=content,
original_prompt=user_prompt,
)
content = await llm_call(
retry_prompt,
system=system_prompt,
model=model,
max_tokens=16384,
)
retries_used = attempt
# Strip any markdown code fences the LLM might have wrapped around output
content = _strip_wrapping_fences(content)
# Validate
last_issues = _validate_content(content, file_entry, fact_shard)
if not last_issues:
break
logger.warning(
"%s attempt %d validation issues: %s",
file_id,
attempt + 1,
last_issues,
)
except Exception as exc:
logger.error("%s attempt %d error: %s", file_id, attempt + 1, exc)
last_issues = [f"Generation error: {exc}"]
if attempt == MAX_RETRIES:
# All retries exhausted — log failure and return whatever we have
elapsed = time.monotonic() - t0
if gen_log:
gen_log.log_file(
file_id,
model=model,
retries=retries_used,
status="failed",
error=str(exc),
elapsed_s=elapsed,
)
logger.error(
"Failed to generate %s after %d attempts: %s",
file_id,
MAX_RETRIES + 1,
exc,
)
return content
elapsed = time.monotonic() - t0
# Even if there are remaining issues after all retries, write the best attempt
status = "ok" if not last_issues else "partial"
if last_issues:
logger.warning(
"%s: writing with unresolved issues after %d retries: %s",
file_id,
retries_used,
last_issues,
)
# Write to disk
write_text(dest, content)
logger.info(
"Generated %s (%d tokens, %d retries, %.1fs) -> %s",
file_id,
count_tokens(content),
retries_used,
elapsed,
dest,
)
# Log
if gen_log:
gen_log.log_file(
file_id,
model=model,
tokens_out=count_tokens(content),
retries=retries_used,
status=status,
error="; ".join(last_issues) if last_issues else None,
elapsed_s=elapsed,
)
return content
def _strip_wrapping_fences(text: str) -> str:
"""Remove markdown code fences that an LLM might wrap around the output."""
stripped = text.strip()
if stripped.startswith("```"):
lines = stripped.split("\n")
# Remove opening fence (e.g. ```markdown, ```text, ```)
if lines[0].startswith("```"):
lines = lines[1:]
# Remove closing fence
if lines and lines[-1].strip() == "```":
lines = lines[:-1]
return "\n".join(lines)
return text
# ---------------------------------------------------------------------------
# Cluster-level generation
# ---------------------------------------------------------------------------
async def generate_cluster(
cluster: Any,
manifest_entries: dict[str, dict],
fact_shard: dict,
output_dir: Path,
context_files: dict[str, str],
model: str = DEFAULT_MODEL,
gen_log: GenerationLog | None = None,
) -> dict[str, str]:
"""Generate all files in a cluster sequentially.
Returns dict of file_id -> content for all generated files.
Each file in the cluster sees previously generated files as context.
Args:
cluster: a Cluster object with ``file_entries`` (list[dict]) and
``level`` (int). Each entry dict must contain a ``file_id`` key.
manifest_entries: file_id -> manifest entry for ALL files.
fact_shard: the fact registry (or relevant shard).
output_dir: base output directory.
context_files: files from dependency clusters (file_id -> content).
model: LLM model to use.
gen_log: optional generation log.
"""
# Merge dependency context with what we generate in this cluster
combined_context: dict[str, str] = dict(context_files)
generated: dict[str, str] = {}
# Extract ordered file IDs from the cluster's file_entries list
file_ids = _get_cluster_file_ids(cluster)
for file_id in file_ids:
entry = manifest_entries.get(file_id)
if entry is None:
logger.warning(
"File %s in cluster but not in manifest — skipping", file_id
)
continue
# Build cross-reference context for this specific file
file_context = _build_context_files(entry, combined_context, manifest_entries)
content = await generate_file(
file_entry=entry,
fact_shard=fact_shard,
context_files=file_context,
output_dir=output_dir,
model=model,
gen_log=gen_log,
manifest_entries=manifest_entries,
)
generated[file_id] = content
combined_context[file_id] = content
return generated
# ---------------------------------------------------------------------------
# Top-level orchestrator
# ---------------------------------------------------------------------------
async def generate_all(
clusters: list[Any],
manifest_entries: dict[str, dict],
output_dir: Path,
model: str = DEFAULT_MODEL,
max_concurrent: int = 10,
gen_log: GenerationLog | None = None,
fallback_fact_registry: dict | None = None,
) -> None:
"""Generate all files across all clusters, respecting topological order.
Clusters at the same ``level`` run in parallel (up to *max_concurrent*).
Clusters at different levels run sequentially (lower levels first).
Each cluster uses its own ``cluster.fact_shard`` (set by the clusterer's
sharding logic). If a cluster has no ``fact_shard`` attribute or it is
empty, *fallback_fact_registry* is used instead.
Args:
clusters: list of Cluster-like objects, **already ordered by level**.
Each should have a ``fact_shard`` attribute (dict) set by the
clusterer.
manifest_entries: file_id -> manifest entry for ALL files.
output_dir: base output directory.
model: LLM model to use.
max_concurrent: maximum number of clusters processed in parallel
within a single level.
gen_log: optional generation log.
fallback_fact_registry: full fact registry used when a cluster has no
``fact_shard``.
"""
# Group clusters by level
levels: dict[int, list[Any]] = {}
for cluster in clusters:
level = getattr(cluster, "level", 0)
levels.setdefault(level, []).append(cluster)
# All generated content so far (shared across levels)
all_generated: dict[str, str] = {}
for level_num in sorted(levels.keys()):
level_clusters = levels[level_num]
logger.info(
"Level %d: processing %d cluster(s) (up to %d concurrent)",
level_num,
len(level_clusters),
max_concurrent,
)
sem = asyncio.Semaphore(max_concurrent)
async def _run_cluster(c: Any) -> dict[str, str]:
async with sem:
# Use the cluster's own sharded fact registry; fall back to
# the full registry if the cluster doesn't have one.
cluster_facts = getattr(c, "fact_shard", None) or {}
if not cluster_facts and fallback_fact_registry:
cluster_facts = fallback_fact_registry
# Snapshot current generated content as context for this cluster
return await generate_cluster(
cluster=c,
manifest_entries=manifest_entries,
fact_shard=cluster_facts,
output_dir=output_dir,
context_files=dict(all_generated),
model=model,
gen_log=gen_log,
)
results = await asyncio.gather(
*(_run_cluster(c) for c in level_clusters),
return_exceptions=True,
)
for i, result in enumerate(results):
if isinstance(result, Exception):
try:
cluster_ids = _get_cluster_file_ids(level_clusters[i])
except Exception:
cluster_ids = [f"<cluster index {i}>"]
logger.error(
"Cluster %s at level %d failed: %s",
cluster_ids,
level_num,
result,
)
else:
all_generated.update(result)
if gen_log:
summary = gen_log.summary()
logger.info("Generation complete. Summary: %s", summary)