mirror of
https://github.com/666ghj/MiroFish.git
synced 2026-04-28 06:31:25 +00:00
2716 lines
92 KiB
Python
2716 lines
92 KiB
Python
"""
|
||
模拟相关API路由
|
||
Step2: Zep实体读取与过滤、OASIS模拟准备与运行(全程自动化)
|
||
"""
|
||
|
||
import os
import traceback
from typing import Optional

from flask import request, jsonify, send_file

from . import simulation_bp
from ..config import Config
from ..services.zep_entity_reader import ZepEntityReader
from ..services.oasis_profile_generator import OasisProfileGenerator
from ..services.simulation_manager import SimulationManager, SimulationStatus
from ..services.simulation_runner import SimulationRunner, RunnerStatus
from ..utils.logger import get_logger
from ..utils.locale import t, get_locale, set_locale
from ..models.project import ProjectManager
|
||
|
||
logger = get_logger('mirofish.api.simulation')


# Prefix prepended to Interview prompts.
# Adding it steers the agent to answer directly in plain text instead of calling tools.
INTERVIEW_PROMPT_PREFIX = "结合你的人设、所有的过往记忆与行动,不调用任何工具直接用文本回复我:"
|
||
|
||
|
||
def optimize_interview_prompt(prompt: str) -> str:
    """
    Put INTERVIEW_PROMPT_PREFIX in front of an interview question.

    The prefix asks the agent to answer in plain text without calling tools.
    A falsy prompt (e.g. an empty string) is returned unchanged. A prompt that
    already begins with the prefix is also returned as-is, so the prefix is
    never applied twice.

    Args:
        prompt: the original interview question

    Returns:
        The (possibly) prefixed question.
    """
    if prompt and not prompt.startswith(INTERVIEW_PROMPT_PREFIX):
        return f"{INTERVIEW_PROMPT_PREFIX}{prompt}"
    return prompt
|
||
|
||
|
||
# ============== 实体读取接口 ==============
|
||
|
||
@simulation_bp.route('/entities/<graph_id>', methods=['GET'])
def get_graph_entities(graph_id: str):
    """
    Return the filtered entities of a graph.

    Only nodes matching a predefined entity type are returned, i.e. nodes
    whose labels are not just "Entity".

    Query parameters:
        entity_types: optional comma-separated list of entity types used to narrow the result
        enrich: "true" (default) to attach related edge information
    """
    try:
        if not Config.ZEP_API_KEY:
            return jsonify({
                "success": False,
                "error": t('api.zepApiKeyMissing')
            }), 500

        raw_types = request.args.get('entity_types', '')
        requested_types = None
        if raw_types:
            # `part`, not `t`, so the i18n helper `t` is never shadowed here
            requested_types = [part.strip() for part in raw_types.split(',') if part.strip()]
        with_edges = request.args.get('enrich', 'true').lower() == 'true'

        logger.info(f"获取图谱实体: graph_id={graph_id}, entity_types={requested_types}, enrich={with_edges}")

        filtered = ZepEntityReader().filter_defined_entities(
            graph_id=graph_id,
            defined_entity_types=requested_types,
            enrich_with_edges=with_edges
        )
        return jsonify({"success": True, "data": filtered.to_dict()})

    except Exception as e:
        logger.error(f"获取图谱实体失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
@simulation_bp.route('/entities/<graph_id>/<entity_uuid>', methods=['GET'])
def get_entity_detail(graph_id: str, entity_uuid: str):
    """Return a single entity of a graph with its context; 404 when it is not found."""
    try:
        if not Config.ZEP_API_KEY:
            missing_key = {"success": False, "error": t('api.zepApiKeyMissing')}
            return jsonify(missing_key), 500

        found = ZepEntityReader().get_entity_with_context(graph_id, entity_uuid)
        if found:
            return jsonify({"success": True, "data": found.to_dict()})

        return jsonify({
            "success": False,
            "error": t('api.entityNotFound', id=entity_uuid)
        }), 404

    except Exception as e:
        logger.error(f"获取实体详情失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
@simulation_bp.route('/entities/<graph_id>/by-type/<entity_type>', methods=['GET'])
def get_entities_by_type(graph_id: str, entity_type: str):
    """
    Return every entity of one type in a graph.

    Query parameters:
        enrich: "true" (default) to attach related edge information
    """
    try:
        if not Config.ZEP_API_KEY:
            return jsonify({"success": False, "error": t('api.zepApiKeyMissing')}), 500

        with_edges = request.args.get('enrich', 'true').lower() == 'true'
        matches = ZepEntityReader().get_entities_by_type(
            graph_id=graph_id,
            entity_type=entity_type,
            enrich_with_edges=with_edges
        )

        payload = {
            "entity_type": entity_type,
            "count": len(matches),
            "entities": [item.to_dict() for item in matches],
        }
        return jsonify({"success": True, "data": payload})

    except Exception as e:
        logger.error(f"获取实体失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
# ============== 模拟管理接口 ==============
|
||
|
||
@simulation_bp.route('/create', methods=['POST'])
def create_simulation():
    """
    Create a new simulation.

    Note: parameters such as max_rounds are generated by the LLM and do not
    need to be supplied here.

    Request (JSON):
        {
            "project_id": "proj_xxxx",      // required
            "graph_id": "mirofish_xxxx",    // optional, falls back to the project's graph
            "enable_twitter": true,         // optional, default true
            "enable_reddit": true           // optional, default true
        }

    Response:
        {
            "success": true,
            "data": {
                "simulation_id": "sim_xxxx",
                "project_id": "proj_xxxx",
                "graph_id": "mirofish_xxxx",
                "status": "created",
                "enable_twitter": true,
                "enable_reddit": true,
                "created_at": "2025-12-01T10:00:00"
            }
        }

    Errors: 400 (no project_id, or the project has no graph yet), 404 (unknown
    project), 500 (unexpected failure).
    """
    try:
        body = request.get_json() or {}

        project_id = body.get('project_id')
        if not project_id:
            return jsonify({"success": False, "error": t('api.requireProjectId')}), 400

        project = ProjectManager.get_project(project_id)
        if not project:
            return jsonify({
                "success": False,
                "error": t('api.projectNotFound', id=project_id)
            }), 404

        # Use the graph named in the request, otherwise the project's own graph
        graph_id = body.get('graph_id') or project.graph_id
        if not graph_id:
            return jsonify({"success": False, "error": t('api.graphNotBuilt')}), 400

        new_state = SimulationManager().create_simulation(
            project_id=project_id,
            graph_id=graph_id,
            enable_twitter=body.get('enable_twitter', True),
            enable_reddit=body.get('enable_reddit', True),
        )
        return jsonify({"success": True, "data": new_state.to_dict()})

    except Exception as e:
        logger.error(f"创建模拟失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
def _check_simulation_prepared(simulation_id: str) -> tuple:
    """
    Check whether a simulation has finished its preparation step.

    A simulation counts as prepared when:
    1. its directory contains state.json, simulation_config.json,
       reddit_profiles.json and twitter_profiles.csv, and
    2. state.json has a truthy ``config_generated`` and a ``status`` in
       ready / preparing / running / completed / stopped / failed.

    The run_*.py scripts stay in backend/scripts/ and are not copied into the
    simulation directory, so they are not checked here.

    Side effect: if state.json still says "preparing" although everything is
    in place, it is rewritten with status "ready" and a fresh ``updated_at``.
    The rewrite goes through a temp file + ``os.replace`` so a failed write
    never leaves a truncated state.json behind.

    Args:
        simulation_id: simulation ID (directory name under
            Config.OASIS_SIMULATION_DATA_DIR)

    Returns:
        (is_prepared, info): when prepared, ``info`` carries state details
        (status, counts, timestamps, existing files); otherwise it carries a
        ``reason`` plus diagnostic fields.
    """
    import json
    from datetime import datetime

    simulation_dir = os.path.join(Config.OASIS_SIMULATION_DATA_DIR, simulation_id)

    # The simulation directory itself must exist
    if not os.path.exists(simulation_dir):
        return False, {"reason": "模拟目录不存在"}

    # Required files (run scripts live in backend/scripts/ and are not listed)
    required_files = [
        "state.json",
        "simulation_config.json",
        "reddit_profiles.json",
        "twitter_profiles.csv"
    ]

    existing_files = []
    missing_files = []
    for file_name in required_files:
        if os.path.exists(os.path.join(simulation_dir, file_name)):
            existing_files.append(file_name)
        else:
            missing_files.append(file_name)

    if missing_files:
        return False, {
            "reason": "缺少必要文件",
            "missing_files": missing_files,
            "existing_files": existing_files
        }

    state_file = os.path.join(simulation_dir, "state.json")
    try:
        with open(state_file, 'r', encoding='utf-8') as f:
            state_data = json.load(f)

        status = state_data.get("status", "")
        config_generated = state_data.get("config_generated", False)

        logger.debug(f"检测模拟准备状态: {simulation_id}, status={status}, config_generated={config_generated}")

        # With config_generated=True and all files present, each of these statuses
        # means preparation is finished:
        # - ready: prepared, can be run
        # - preparing: generation done but the status was not flipped yet
        # - running / completed / stopped / failed: preparation finished before the run
        prepared_statuses = ["ready", "preparing", "running", "completed", "stopped", "failed"]
        if not (status in prepared_statuses and config_generated):
            logger.warning(f"模拟 {simulation_id} 检测结果: 未准备完成 (status={status}, config_generated={config_generated})")
            return False, {
                "reason": f"状态不在已准备列表中或config_generated为false: status={status}, config_generated={config_generated}",
                "status": status,
                "config_generated": config_generated
            }

        # Number of generated profiles (0 unless the file holds a JSON list)
        profiles_file = os.path.join(simulation_dir, "reddit_profiles.json")
        profiles_count = 0
        if os.path.exists(profiles_file):
            with open(profiles_file, 'r', encoding='utf-8') as f:
                profiles_data = json.load(f)
            profiles_count = len(profiles_data) if isinstance(profiles_data, list) else 0

        # Files are complete but the status still says "preparing": persist "ready"
        if status == "preparing":
            updated_state = dict(state_data)
            updated_state["status"] = "ready"
            updated_state["updated_at"] = datetime.now().isoformat()
            tmp_file = state_file + ".tmp"
            try:
                # Write a temp file and atomically swap it in, so concurrent readers of
                # state.json never observe a truncated or half-written file.
                with open(tmp_file, 'w', encoding='utf-8') as f:
                    json.dump(updated_state, f, ensure_ascii=False, indent=2)
                os.replace(tmp_file, state_file)
            except Exception as e:
                logger.warning(f"自动更新状态失败: {e}")
                # Best-effort cleanup of the temp file; the original state.json is intact
                if os.path.exists(tmp_file):
                    try:
                        os.remove(tmp_file)
                    except OSError:
                        pass
            else:
                # Report the new status/timestamp only once they are actually on disk
                state_data = updated_state
                logger.info(f"自动更新模拟状态: {simulation_id} preparing -> ready")
                status = "ready"

        logger.info(f"模拟 {simulation_id} 检测结果: 已准备完成 (status={status}, config_generated={config_generated})")
        return True, {
            "status": status,
            "entities_count": state_data.get("entities_count", 0),
            "profiles_count": profiles_count,
            "entity_types": state_data.get("entity_types", []),
            "config_generated": config_generated,
            "created_at": state_data.get("created_at"),
            "updated_at": state_data.get("updated_at"),
            "existing_files": existing_files
        }

    except Exception as e:
        return False, {"reason": f"读取状态文件失败: {str(e)}"}
|
||
|
||
|
||
@simulation_bp.route('/prepare', methods=['POST'])
def prepare_simulation():
    """
    Prepare the simulation environment as an asynchronous task (the LLM generates all parameters).

    This is a slow operation: the endpoint returns a task_id immediately, and
    progress is polled with POST /api/simulation/prepare/status (that route
    only accepts POST).

    Features:
    - Detects preparation that has already been completed and skips regeneration
    - Returns the existing result when the simulation is already prepared
    - Supports forced regeneration (force_regenerate=true)

    Steps:
    1. Check whether a completed preparation already exists
    2. Read and filter entities from the Zep graph
    3. Generate an OASIS Agent Profile for each entity (with retries)
    4. Let the LLM generate the simulation config (with retries)
    5. Save the config files and preset scripts

    Request (JSON):
        {
            "simulation_id": "sim_xxxx",                  // required
            "entity_types": ["Student", "PublicFigure"],  // optional, restrict entity types
            "use_llm_for_profiles": true,                 // optional, default true
            "parallel_profile_count": 5,                  // optional, profiles generated in parallel, default 5
            "force_regenerate": false                     // optional, default false
        }

    Response:
        {
            "success": true,
            "data": {
                "simulation_id": "sim_xxxx",
                "task_id": "task_xxxx",          // only when a new task was started
                "status": "preparing|ready",
                "message": "...",
                "already_prepared": true|false,
                "prepare_info": {...},           // only when already prepared
                "expected_entities_count": 68,   // only when a new task was started
                "entity_types": [...]            // only when a new task was started
            }
        }

    Errors: 400 (missing simulation_id, or the project has no simulation
    requirement), 404 (unknown simulation/project, or a ValueError), 500
    (anything else).
    """
    import threading
    from ..models.task import TaskManager, TaskStatus

    try:
        data = request.get_json() or {}

        simulation_id = data.get('simulation_id')
        if not simulation_id:
            return jsonify({
                "success": False,
                "error": t('api.requireSimulationId')
            }), 400

        manager = SimulationManager()
        state = manager.get_simulation(simulation_id)

        if not state:
            return jsonify({
                "success": False,
                "error": t('api.simulationNotFound', id=simulation_id)
            }), 404

        # Whether the caller wants to regenerate even if output already exists
        force_regenerate = data.get('force_regenerate', False)
        logger.info(f"开始处理 /prepare 请求: simulation_id={simulation_id}, force_regenerate={force_regenerate}")

        # Skip the expensive preparation when it has already been completed
        if not force_regenerate:
            logger.debug(f"检查模拟 {simulation_id} 是否已准备完成...")
            is_prepared, prepare_info = _check_simulation_prepared(simulation_id)
            logger.debug(f"检查结果: is_prepared={is_prepared}, prepare_info={prepare_info}")
            if is_prepared:
                logger.info(f"模拟 {simulation_id} 已准备完成,跳过重复生成")
                return jsonify({
                    "success": True,
                    "data": {
                        "simulation_id": simulation_id,
                        "status": "ready",
                        "message": t('api.alreadyPrepared'),
                        "already_prepared": True,
                        "prepare_info": prepare_info
                    }
                })
            else:
                logger.info(f"模拟 {simulation_id} 未准备完成,将启动准备任务")

        # Pull the required inputs from the project
        project = ProjectManager.get_project(state.project_id)
        if not project:
            return jsonify({
                "success": False,
                "error": t('api.projectNotFound', id=state.project_id)
            }), 404

        simulation_requirement = project.simulation_requirement or ""
        if not simulation_requirement:
            return jsonify({
                "success": False,
                "error": t('api.projectMissingRequirement')
            }), 400

        document_text = ProjectManager.get_extracted_text(state.project_id) or ""

        entity_types_list = data.get('entity_types')
        use_llm_for_profiles = data.get('use_llm_for_profiles', True)
        parallel_profile_count = data.get('parallel_profile_count', 5)

        # Count entities synchronously (before the background task starts) so the
        # frontend knows the expected number of agents right after calling /prepare.
        try:
            logger.info(f"同步获取实体数量: graph_id={state.graph_id}")
            reader = ZepEntityReader()
            # Edges are not needed just to count entities, so skip them for speed
            filtered_preview = reader.filter_defined_entities(
                graph_id=state.graph_id,
                defined_entity_types=entity_types_list,
                enrich_with_edges=False
            )
            state.entities_count = filtered_preview.filtered_count
            state.entity_types = list(filtered_preview.entity_types)
            logger.info(f"预期实体数量: {filtered_preview.filtered_count}, 类型: {filtered_preview.entity_types}")
        except Exception as e:
            # Non-fatal: the background task reads the entities again anyway
            logger.warning(f"同步获取实体数量失败(将在后台任务中重试): {e}")

        task_manager = TaskManager()
        task_id = task_manager.create_task(
            task_type="simulation_prepare",
            metadata={
                "simulation_id": simulation_id,
                "project_id": state.project_id
            }
        )

        # Persist the PREPARING status together with the pre-computed entity counts
        state.status = SimulationStatus.PREPARING
        manager._save_simulation_state(state)

        # Capture locale before spawning background thread
        current_locale = get_locale()

        # Overall-progress window (start %, end %) for each stage, in stage order
        stage_weights = {
            "reading": (0, 20),
            "generating_profiles": (20, 70),
            "generating_config": (70, 90),
            "copying_scripts": (90, 100)
        }
        stage_order = list(stage_weights)

        def run_prepare():
            """Background worker: run the preparation and mirror its progress into the task."""
            set_locale(current_locale)
            try:
                task_manager.update_task(
                    task_id,
                    status=TaskStatus.PROCESSING,
                    progress=0,
                    message=t('progress.startPreparingEnv')
                )

                # Latest per-stage progress details, keyed by stage
                stage_details = {}

                def progress_callback(stage, progress, message, **kwargs):
                    # Map the stage-local percentage onto the stage's overall window
                    start, end = stage_weights.get(stage, (0, 100))
                    current_progress = int(start + (end - start) * progress / 100)

                    # Built per call so t() resolves with this thread's locale
                    stage_names = {
                        "reading": t('progress.readingGraphEntities'),
                        "generating_profiles": t('progress.generatingProfiles'),
                        "generating_config": t('progress.generatingSimConfig'),
                        "copying_scripts": t('progress.preparingScripts')
                    }
                    stage_name = stage_names.get(stage, stage)

                    stage_index = stage_order.index(stage) + 1 if stage in stage_weights else 1
                    total_stages = len(stage_order)

                    stage_details[stage] = {
                        "stage_name": stage_name,
                        "stage_progress": progress,
                        "current": kwargs.get("current", 0),
                        "total": kwargs.get("total", 0),
                        "item_name": kwargs.get("item_name", "")
                    }
                    detail = stage_details[stage]

                    progress_detail_data = {
                        "current_stage": stage,
                        "current_stage_name": stage_name,
                        "stage_index": stage_index,
                        "total_stages": total_stages,
                        "stage_progress": progress,
                        "current_item": detail["current"],
                        "total_items": detail["total"],
                        "item_description": message
                    }

                    # Short human-readable message, with item counts when known
                    if detail["total"] > 0:
                        detailed_message = (
                            f"[{stage_index}/{total_stages}] {stage_name}: "
                            f"{detail['current']}/{detail['total']} - {message}"
                        )
                    else:
                        detailed_message = f"[{stage_index}/{total_stages}] {stage_name}: {message}"

                    task_manager.update_task(
                        task_id,
                        progress=current_progress,
                        message=detailed_message,
                        progress_detail=progress_detail_data
                    )

                result_state = manager.prepare_simulation(
                    simulation_id=simulation_id,
                    simulation_requirement=simulation_requirement,
                    document_text=document_text,
                    defined_entity_types=entity_types_list,
                    use_llm_for_profiles=use_llm_for_profiles,
                    progress_callback=progress_callback,
                    parallel_profile_count=parallel_profile_count
                )

                task_manager.complete_task(
                    task_id,
                    result=result_state.to_simple_dict()
                )

            except Exception as e:
                # This runs off the request thread: log the full traceback, since
                # nothing else will capture it.
                logger.exception(f"准备模拟失败: {str(e)}")
                task_manager.fail_task(task_id, str(e))

                # Mark the simulation itself as failed as well
                failed_state = manager.get_simulation(simulation_id)
                if failed_state:
                    failed_state.status = SimulationStatus.FAILED
                    failed_state.error = str(e)
                    manager._save_simulation_state(failed_state)

        thread = threading.Thread(target=run_prepare, daemon=True)
        thread.start()

        return jsonify({
            "success": True,
            "data": {
                "simulation_id": simulation_id,
                "task_id": task_id,
                "status": "preparing",
                "message": t('api.prepareStarted'),
                "already_prepared": False,
                "expected_entities_count": state.entities_count,  # expected number of agents
                "entity_types": state.entity_types
            }
        })

    except ValueError as e:
        return jsonify({
            "success": False,
            "error": str(e)
        }), 404

    except Exception as e:
        logger.error(f"启动准备任务失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
@simulation_bp.route('/prepare/status', methods=['POST'])
def get_prepare_status():
    """
    Report the progress of a preparation task.

    Two lookup modes are supported:
    1. by task_id: progress of a running preparation task
    2. by simulation_id: whether a completed preparation already exists on disk

    Request (JSON):
        {
            "task_id": "task_xxxx",        // optional, task_id returned by /prepare
            "simulation_id": "sim_xxxx"    // optional, checks for a finished preparation
        }

    Response:
        {
            "success": true,
            "data": {
                "task_id": "task_xxxx",
                "status": "processing|completed|ready",
                "progress": 45,
                "message": "...",
                "already_prepared": true|false,
                "prepare_info": {...}      // present when already prepared
            }
        }
    """
    from ..models.task import TaskManager

    try:
        payload = request.get_json() or {}
        task_id = payload.get('task_id')
        simulation_id = payload.get('simulation_id')

        # Fast path: the simulation is already fully prepared on disk
        if simulation_id:
            prepared, info = _check_simulation_prepared(simulation_id)
            if prepared:
                return jsonify({
                    "success": True,
                    "data": {
                        "simulation_id": simulation_id,
                        "status": "ready",
                        "progress": 100,
                        "message": t('api.alreadyPrepared'),
                        "already_prepared": True,
                        "prepare_info": info
                    }
                })

        # Without a task_id the best we can say is "not started" (or reject the call)
        if not task_id:
            if not simulation_id:
                return jsonify({
                    "success": False,
                    "error": t('api.requireTaskOrSimId')
                }), 400
            return jsonify({
                "success": True,
                "data": {
                    "simulation_id": simulation_id,
                    "status": "not_started",
                    "progress": 0,
                    "message": t('api.notStartedPrepare'),
                    "already_prepared": False
                }
            })

        task = TaskManager().get_task(task_id)
        if task:
            task_payload = task.to_dict()
            task_payload["already_prepared"] = False
            return jsonify({"success": True, "data": task_payload})

        # Unknown task: fall back to checking the simulation's files once more
        if simulation_id:
            prepared, info = _check_simulation_prepared(simulation_id)
            if prepared:
                return jsonify({
                    "success": True,
                    "data": {
                        "simulation_id": simulation_id,
                        "task_id": task_id,
                        "status": "ready",
                        "progress": 100,
                        "message": t('api.taskCompletedPrepared'),
                        "already_prepared": True,
                        "prepare_info": info
                    }
                })

        return jsonify({
            "success": False,
            "error": t('api.taskNotFound', id=task_id)
        }), 404

    except Exception as e:
        logger.error(f"查询任务状态失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e)
        }), 500
|
||
|
||
|
||
@simulation_bp.route('/<simulation_id>', methods=['GET'])
def get_simulation(simulation_id: str):
    """Return a simulation's state, plus run instructions once it is READY."""
    try:
        sim_manager = SimulationManager()
        sim_state = sim_manager.get_simulation(simulation_id)

        if not sim_state:
            not_found = {
                "success": False,
                "error": t('api.simulationNotFound', id=simulation_id)
            }
            return jsonify(not_found), 404

        details = sim_state.to_dict()
        is_ready = sim_state.status == SimulationStatus.READY
        if is_ready:
            details["run_instructions"] = sim_manager.get_run_instructions(simulation_id)

        return jsonify({"success": True, "data": details})

    except Exception as e:
        logger.error(f"获取模拟状态失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
@simulation_bp.route('/list', methods=['GET'])
def list_simulations():
    """
    List all simulations.

    Query parameters:
        project_id: optional, only return simulations of this project
    """
    try:
        wanted_project = request.args.get('project_id')

        sims = SimulationManager().list_simulations(project_id=wanted_project)
        serialized = [sim.to_dict() for sim in sims]

        return jsonify({
            "success": True,
            "data": serialized,
            "count": len(sims)
        })

    except Exception as e:
        logger.error(f"列出模拟失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
def _get_report_id_for_simulation(simulation_id: str) -> str:
|
||
"""
|
||
获取 simulation 对应的最新 report_id
|
||
|
||
遍历 reports 目录,找出 simulation_id 匹配的 report,
|
||
如果有多个则返回最新的(按 created_at 排序)
|
||
|
||
Args:
|
||
simulation_id: 模拟ID
|
||
|
||
Returns:
|
||
report_id 或 None
|
||
"""
|
||
import json
|
||
from datetime import datetime
|
||
|
||
# reports 目录路径:backend/uploads/reports
|
||
# __file__ 是 app/api/simulation.py,需要向上两级到 backend/
|
||
reports_dir = os.path.join(os.path.dirname(__file__), '../../uploads/reports')
|
||
if not os.path.exists(reports_dir):
|
||
return None
|
||
|
||
matching_reports = []
|
||
|
||
try:
|
||
for report_folder in os.listdir(reports_dir):
|
||
report_path = os.path.join(reports_dir, report_folder)
|
||
if not os.path.isdir(report_path):
|
||
continue
|
||
|
||
meta_file = os.path.join(report_path, "meta.json")
|
||
if not os.path.exists(meta_file):
|
||
continue
|
||
|
||
try:
|
||
with open(meta_file, 'r', encoding='utf-8') as f:
|
||
meta = json.load(f)
|
||
|
||
if meta.get("simulation_id") == simulation_id:
|
||
matching_reports.append({
|
||
"report_id": meta.get("report_id"),
|
||
"created_at": meta.get("created_at", ""),
|
||
"status": meta.get("status", "")
|
||
})
|
||
except Exception:
|
||
continue
|
||
|
||
if not matching_reports:
|
||
return None
|
||
|
||
# 按创建时间倒序排序,返回最新的
|
||
matching_reports.sort(key=lambda x: x.get("created_at", ""), reverse=True)
|
||
return matching_reports[0].get("report_id")
|
||
|
||
except Exception as e:
|
||
logger.warning(f"查找 simulation {simulation_id} 的 report 失败: {e}")
|
||
return None
|
||
|
||
|
||
@simulation_bp.route('/history', methods=['GET'])
def get_simulation_history():
    """
    List past simulations enriched with project details.

    Used by the home page history view. Each simulation dict is extended with
    its requirement text, round counts, runner status, up to three project
    file names, the latest report_id and a version tag.

    Query parameters:
        limit: maximum number of simulations to return (default 20; negative values are treated as 0)

    Response:
        {
            "success": true,
            "data": [
                {
                    "simulation_id": "sim_xxxx",
                    "project_id": "proj_xxxx",
                    "project_name": "武大舆情分析",
                    "simulation_requirement": "如果武汉大学发布...",
                    "status": "completed",
                    "entities_count": 68,
                    "profiles_count": 68,
                    "entity_types": ["Student", "Professor", ...],
                    "created_at": "2024-12-10",
                    "updated_at": "2024-12-10",
                    "total_rounds": 120,
                    "current_round": 120,
                    "report_id": "report_xxxx",
                    "version": "v1.0.2"
                },
                ...
            ],
            "count": 7
        }
    """
    try:
        limit = request.args.get('limit', 20, type=int)
        # A negative limit would otherwise slice from the end of the list
        limit = max(limit, 0)

        manager = SimulationManager()
        simulations = manager.list_simulations()[:limit]

        # Enrich each simulation using only its own simulation files
        enriched_simulations = []
        for sim in simulations:
            sim_dict = sim.to_dict()

            # simulation_config.json: requirement text and time settings
            config = manager.get_simulation_config(sim.simulation_id)
            if config:
                sim_dict["simulation_requirement"] = config.get("simulation_requirement", "")
                # Keys present with a null value must not crash the whole listing
                time_config = config.get("time_config") or {}
                total_hours = time_config.get("total_simulation_hours", 0)
                sim_dict["total_simulation_hours"] = total_hours
                minutes_per_round = time_config.get("minutes_per_round", 60)
                if minutes_per_round is None:
                    minutes_per_round = 60
                # Recommended rounds: fallback when the run state has no explicit total
                recommended_rounds = int((total_hours or 0) * 60 / max(minutes_per_round, 1))
            else:
                sim_dict["simulation_requirement"] = ""
                sim_dict["total_simulation_hours"] = 0
                recommended_rounds = 0

            # run_state.json: progress and the round count actually chosen by the user
            run_state = SimulationRunner.get_run_state(sim.simulation_id)
            if run_state:
                sim_dict["current_round"] = run_state.current_round
                sim_dict["runner_status"] = run_state.runner_status.value
                user_rounds = run_state.total_rounds or 0
                sim_dict["total_rounds"] = user_rounds if user_rounds > 0 else recommended_rounds
            else:
                sim_dict["current_round"] = 0
                sim_dict["runner_status"] = "idle"
                sim_dict["total_rounds"] = recommended_rounds

            # Up to three file names of the owning project
            project = ProjectManager.get_project(sim.project_id)
            project_files = getattr(project, 'files', None) if project else None
            if project_files:
                sim_dict["files"] = [
                    {"filename": f.get("filename", "未知文件")}
                    for f in project_files[:3]
                ]
            else:
                sim_dict["files"] = []

            # Latest report generated for this simulation (None if there is none)
            sim_dict["report_id"] = _get_report_id_for_simulation(sim.simulation_id)

            sim_dict["version"] = "v1.0.2"

            # Date part (YYYY-MM-DD) of the creation timestamp; "" if it is missing
            created_at = sim_dict.get("created_at")
            sim_dict["created_date"] = created_at[:10] if isinstance(created_at, str) else ""

            enriched_simulations.append(sim_dict)

        return jsonify({
            "success": True,
            "data": enriched_simulations,
            "count": len(enriched_simulations)
        })

    except Exception as e:
        logger.error(f"获取历史模拟失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
@simulation_bp.route('/<simulation_id>/profiles', methods=['GET'])
def get_simulation_profiles(simulation_id: str):
    """
    Return the agent profiles of a simulation via SimulationManager.

    Query parameters:
        platform: "reddit" (default) or "twitter"

    A ValueError from the manager is reported as 404.
    """
    try:
        target_platform = request.args.get('platform', 'reddit')
        agent_profiles = SimulationManager().get_profiles(simulation_id, platform=target_platform)

        result = {
            "platform": target_platform,
            "count": len(agent_profiles),
            "profiles": agent_profiles
        }
        return jsonify({"success": True, "data": result})

    except ValueError as e:
        return jsonify({"success": False, "error": str(e)}), 404

    except Exception as e:
        logger.error(f"获取Profile失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
@simulation_bp.route('/<simulation_id>/profiles/realtime', methods=['GET'])
def get_simulation_profiles_realtime(simulation_id: str):
    """
    Read a simulation's agent profiles straight from disk, for watching generation progress live.

    Differences from the /profiles endpoint:
    - reads the files directly instead of going through SimulationManager
    - suitable while profiles are still being generated
    - returns extra metadata (file modification time, whether generation is running)

    Query parameters:
        platform: "reddit" (default, reddit_profiles.json); any other value
            reads twitter_profiles.csv

    Response:
        {
            "success": true,
            "data": {
                "simulation_id": "sim_xxxx",
                "platform": "reddit",
                "count": 15,
                "total_expected": 93,       // entities_count from state.json, if present
                "is_generating": true,      // state.json status == "preparing"
                "file_exists": true,
                "file_modified_at": "2025-12-04T18:20:00",
                "profiles": [...]
            }
        }

    A profiles file that cannot be parsed (e.g. mid-write) yields an empty
    profile list rather than an error.
    """
    import json
    import csv
    from datetime import datetime

    try:
        platform = request.args.get('platform', 'reddit')

        sim_dir = os.path.join(Config.OASIS_SIMULATION_DATA_DIR, simulation_id)
        if not os.path.exists(sim_dir):
            return jsonify({
                "success": False,
                "error": t('api.simulationNotFound', id=simulation_id)
            }), 404

        if platform == "reddit":
            profiles_file = os.path.join(sim_dir, "reddit_profiles.json")
        else:
            profiles_file = os.path.join(sim_dir, "twitter_profiles.csv")

        file_exists = os.path.exists(profiles_file)
        profiles = []
        file_modified_at = None

        if file_exists:
            file_stat = os.stat(profiles_file)
            file_modified_at = datetime.fromtimestamp(file_stat.st_mtime).isoformat()

            try:
                if platform == "reddit":
                    with open(profiles_file, 'r', encoding='utf-8') as f:
                        profiles = json.load(f)
                else:
                    # newline='' is required by the csv module so quoted fields that
                    # contain line breaks (e.g. multi-line personas) are parsed correctly.
                    with open(profiles_file, 'r', encoding='utf-8', newline='') as f:
                        profiles = list(csv.DictReader(f))
            except Exception as e:
                logger.warning(f"读取 profiles 文件失败(可能正在写入中): {e}")
                profiles = []

        # Generation status and expected total come from state.json (best effort)
        is_generating = False
        total_expected = None

        state_file = os.path.join(sim_dir, "state.json")
        if os.path.exists(state_file):
            try:
                with open(state_file, 'r', encoding='utf-8') as f:
                    state_data = json.load(f)
                status = state_data.get("status", "")
                is_generating = status == "preparing"
                total_expected = state_data.get("entities_count")
            except Exception as e:
                # state.json may be mid-write; the metadata is optional, so only note it
                logger.debug(f"read state.json failed for {simulation_id}: {e}")

        return jsonify({
            "success": True,
            "data": {
                "simulation_id": simulation_id,
                "platform": platform,
                "count": len(profiles),
                "total_expected": total_expected,
                "is_generating": is_generating,
                "file_exists": file_exists,
                "file_modified_at": file_modified_at,
                "profiles": profiles
            }
        })

    except Exception as e:
        logger.error(f"实时获取Profile失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
@simulation_bp.route('/<simulation_id>/config/realtime', methods=['GET'])
def get_simulation_config_realtime(simulation_id: str):
    """
    Fetch the simulation config in real time (to watch progress while it is generated).

    Differences from the ``/config`` endpoint:
    - Reads the files directly instead of going through SimulationManager.
    - Intended for live viewing while generation is still in progress.
    - Returns extra metadata (file modification time, whether generation is running, ...).
    - Returns partial information even if the config has not been fully generated yet.

    Returns:
        {
            "success": true,
            "data": {
                "simulation_id": "sim_xxxx",
                "file_exists": true,
                "file_modified_at": "2025-12-04T18:20:00",
                "is_generating": true,                    // whether generation is in progress
                "generation_stage": "generating_config",  // current generation stage
                "config_generated": false,                // flag read from state.json
                "config": {...},                          // config content (if readable)
                "summary": {...}                          // only present when config is non-empty
            }
        }

    Responds 404 when the simulation directory does not exist and 500 on any
    unexpected error.
    """
    import json
    from datetime import datetime

    def _section(cfg, key):
        # A section may be missing, null or not an object (e.g. a half-written or
        # hand-edited file); treat all of those as empty instead of failing the request.
        value = cfg.get(key)
        return value if isinstance(value, dict) else {}

    try:
        # Simulation directory
        sim_dir = os.path.join(Config.OASIS_SIMULATION_DATA_DIR, simulation_id)

        if not os.path.exists(sim_dir):
            return jsonify({
                "success": False,
                "error": t('api.simulationNotFound', id=simulation_id)
            }), 404

        config_file = os.path.join(sim_dir, "simulation_config.json")
        file_exists = os.path.exists(config_file)
        config = None
        file_modified_at = None

        if file_exists:
            try:
                # stat inside the try: the file can be replaced/removed between the
                # existence check and here while the generator is writing it.
                file_modified_at = datetime.fromtimestamp(
                    os.stat(config_file).st_mtime
                ).isoformat()
                with open(config_file, 'r', encoding='utf-8') as f:
                    config = json.load(f)
            except Exception as e:
                # Possibly a partially written file; report "no config yet".
                logger.warning(f"读取 config 文件失败(可能正在写入中): {e}")
                config = None

        # Derive the generation progress from state.json
        is_generating = False
        generation_stage = None
        config_generated = False

        state_file = os.path.join(sim_dir, "state.json")
        if os.path.exists(state_file):
            try:
                with open(state_file, 'r', encoding='utf-8') as f:
                    state_data = json.load(f)
                status = state_data.get("status", "")
                is_generating = status == "preparing"
                config_generated = state_data.get("config_generated", False)

                # Work out the current stage
                if is_generating:
                    if state_data.get("profiles_generated", False):
                        generation_stage = "generating_config"
                    else:
                        generation_stage = "generating_profiles"
                elif status == "ready":
                    generation_stage = "completed"
            except Exception:
                # Best effort: an unreadable state.json just leaves the defaults above.
                pass

        response_data = {
            "simulation_id": simulation_id,
            "file_exists": file_exists,
            "file_modified_at": file_modified_at,
            "is_generating": is_generating,
            "generation_stage": generation_stage,
            "config_generated": config_generated,
            "config": config
        }

        # When a config is available, add a few key statistics
        if config:
            cfg = config if isinstance(config, dict) else {}
            time_config = _section(cfg, "time_config")
            event_config = _section(cfg, "event_config")
            response_data["summary"] = {
                "total_agents": len(cfg.get("agent_configs") or []),
                "simulation_hours": time_config.get("total_simulation_hours"),
                "initial_posts_count": len(event_config.get("initial_posts") or []),
                "hot_topics_count": len(event_config.get("hot_topics") or []),
                "has_twitter_config": "twitter_config" in cfg,
                "has_reddit_config": "reddit_config" in cfg,
                "generated_at": cfg.get("generated_at"),
                "llm_model": cfg.get("llm_model")
            }

        return jsonify({
            "success": True,
            "data": response_data
        })

    except Exception as e:
        logger.error(f"实时获取Config失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
@simulation_bp.route('/<simulation_id>/config', methods=['GET'])
def get_simulation_config(simulation_id: str):
    """
    Return the full simulation config generated by the LLM.

    The payload contains:
    - time_config: timing settings (duration, rounds, peak/off-peak hours)
    - agent_configs: per-agent activity settings (activity level, posting frequency, stance, ...)
    - event_config: event settings (initial posts, hot topics)
    - platform_configs: platform settings
    - generation_reasoning: the LLM's explanation of the chosen config

    Responds 404 when no config is available and 500 on unexpected errors.
    """
    try:
        sim_config = SimulationManager().get_simulation_config(simulation_id)

        if sim_config:
            return jsonify({
                "success": True,
                "data": sim_config
            })

        return jsonify({
            "success": False,
            "error": t('api.configNotFound')
        }), 404

    except Exception as exc:
        logger.error(f"获取配置失败: {str(exc)}")
        return jsonify({
            "success": False,
            "error": str(exc),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
@simulation_bp.route('/<simulation_id>/config/download', methods=['GET'])
def download_simulation_config(simulation_id: str):
    """
    Send the simulation's ``simulation_config.json`` as a file download.

    Responds 404 when the file does not exist and 500 on unexpected errors.
    """
    try:
        sim_dir = SimulationManager()._get_simulation_dir(simulation_id)
        config_path = os.path.join(sim_dir, "simulation_config.json")

        if os.path.exists(config_path):
            return send_file(
                config_path,
                as_attachment=True,
                download_name="simulation_config.json"
            )

        return jsonify({
            "success": False,
            "error": t('api.configFileNotFound')
        }), 404

    except Exception as exc:
        logger.error(f"下载配置失败: {str(exc)}")
        return jsonify({
            "success": False,
            "error": str(exc),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
@simulation_bp.route('/script/<script_name>/download', methods=['GET'])
def download_simulation_script(script_name: str):
    """
    Download one of the generic simulation runner scripts (located in backend/scripts/).

    Allowed ``script_name`` values:
    - run_twitter_simulation.py
    - run_reddit_simulation.py
    - run_parallel_simulation.py
    - action_logger.py

    Any other name is rejected with 400; the whitelist also prevents callers from
    reaching files outside the scripts directory. Responds 404 if the file is missing.
    """
    try:
        allowed_scripts = [
            "run_twitter_simulation.py",
            "run_reddit_simulation.py",
            "run_parallel_simulation.py",
            "action_logger.py"
        ]

        if script_name not in allowed_scripts:
            return jsonify({
                "success": False,
                "error": t('api.unknownScript', name=script_name, allowed=allowed_scripts)
            }), 400

        # The scripts directory sits two levels above this module: backend/scripts/
        scripts_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../scripts'))
        script_path = os.path.join(scripts_dir, script_name)

        if os.path.exists(script_path):
            return send_file(
                script_path,
                as_attachment=True,
                download_name=script_name
            )

        return jsonify({
            "success": False,
            "error": t('api.scriptFileNotFound', name=script_name)
        }), 404

    except Exception as exc:
        logger.error(f"下载脚本失败: {str(exc)}")
        return jsonify({
            "success": False,
            "error": str(exc),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
# ============== Profile生成接口(独立使用) ==============
|
||
|
||
@simulation_bp.route('/generate-profiles', methods=['POST'])
def generate_profiles():
    """
    Generate OASIS agent profiles directly from a graph (without creating a simulation).

    Request (JSON):
        {
            "graph_id": "mirofish_xxxx",   // required
            "entity_types": ["Student"],   // optional
            "use_llm": true,               // optional, default true
            "platform": "reddit"           // optional, default "reddit"
        }

    Profiles are serialized with ``to_reddit_format()`` for "reddit",
    ``to_twitter_format()`` for "twitter" and ``to_dict()`` for any other value.

    Responds 400 when ``graph_id`` is missing (including a missing or malformed
    JSON body) or no matching entities are found, and 500 on unexpected errors.
    """
    try:
        # silent=True: a non-JSON or unparsable body yields None instead of raising a
        # Werkzeug HTTPException that the generic handler below would report as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}

        graph_id = data.get('graph_id')
        if not graph_id:
            return jsonify({
                "success": False,
                "error": t('api.requireGraphId')
            }), 400

        entity_types = data.get('entity_types')
        use_llm = data.get('use_llm', True)
        platform = data.get('platform', 'reddit')

        reader = ZepEntityReader()
        filtered = reader.filter_defined_entities(
            graph_id=graph_id,
            defined_entity_types=entity_types,
            enrich_with_edges=True
        )

        if filtered.filtered_count == 0:
            return jsonify({
                "success": False,
                "error": t('api.noMatchingEntities')
            }), 400

        generator = OasisProfileGenerator()
        profiles = generator.generate_profiles_from_entities(
            entities=filtered.entities,
            use_llm=use_llm
        )

        if platform == "reddit":
            profiles_data = [p.to_reddit_format() for p in profiles]
        elif platform == "twitter":
            profiles_data = [p.to_twitter_format() for p in profiles]
        else:
            profiles_data = [p.to_dict() for p in profiles]

        return jsonify({
            "success": True,
            "data": {
                "platform": platform,
                "entity_types": list(filtered.entity_types),
                "count": len(profiles_data),
                "profiles": profiles_data
            }
        })

    except Exception as e:
        logger.error(f"生成Profile失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
# ============== 模拟运行控制接口 ==============
|
||
|
||
@simulation_bp.route('/start', methods=['POST'])
def start_simulation():
    """
    Start running a simulation.

    Request (JSON):
        {
            "simulation_id": "sim_xxxx",          // required, simulation ID
            "platform": "parallel",               // optional: twitter / reddit / parallel (default)
            "max_rounds": 100,                    // optional: max rounds, truncates overly long runs
            "enable_graph_memory_update": false,  // optional: push agent activity into the Zep graph memory
            "force": false                        // optional: force a restart (stops a running simulation and cleans logs)
        }

    About ``force``:
    - When enabled and the simulation is running or has finished, it is stopped
      first and its run logs are cleaned up.
    - Cleaned files include run_state.json, actions.jsonl, simulation.log, etc.
    - The config file (simulation_config.json) and profile files are not removed.
    - Use it when a simulation needs to be re-run.

    About ``enable_graph_memory_update``:
    - When enabled, every agent activity (posts, comments, likes, ...) is pushed
      to the Zep graph in real time.
    - This lets the graph "remember" the simulation for later analysis or AI chat.
    - Requires the simulation's project to have a valid graph_id.
    - Updates are batched to reduce the number of API calls.

    Returns:
        {
            "success": true,
            "data": {
                "simulation_id": "sim_xxxx",
                "runner_status": "running",
                "process_pid": 12345,
                "twitter_running": true,
                "reddit_running": true,
                "started_at": "2025-12-01T10:00:00",
                "graph_memory_update_enabled": true,  // whether graph memory updates are enabled
                "force_restarted": true               // whether this was a forced restart
            }
        }

    Errors: 400 for invalid input, an unprepared simulation, or a running
    simulation without ``force``; 404 for an unknown simulation; 500 otherwise.
    """
    try:
        # silent=True: a non-JSON or unparsable body yields None instead of raising a
        # Werkzeug HTTPException that the generic handler below would report as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}

        simulation_id = data.get('simulation_id')
        if not simulation_id:
            return jsonify({
                "success": False,
                "error": t('api.requireSimulationId')
            }), 400

        platform = data.get('platform', 'parallel')
        max_rounds = data.get('max_rounds')  # optional: max number of rounds
        enable_graph_memory_update = data.get('enable_graph_memory_update', False)  # optional
        force = data.get('force', False)  # optional: force a restart

        # Validate max_rounds
        if max_rounds is not None:
            try:
                # bool is a subclass of int, so `true` would otherwise become 1 round.
                if isinstance(max_rounds, bool):
                    raise TypeError("max_rounds must be a number")
                # OverflowError: json accepts Infinity, and int(float('inf')) overflows.
                max_rounds = int(max_rounds)
            except (ValueError, TypeError, OverflowError):
                return jsonify({
                    "success": False,
                    "error": t('api.maxRoundsInvalid')
                }), 400
            if max_rounds <= 0:
                return jsonify({
                    "success": False,
                    "error": t('api.maxRoundsPositive')
                }), 400

        if platform not in ['twitter', 'reddit', 'parallel']:
            return jsonify({
                "success": False,
                "error": t('api.invalidPlatform', platform=platform)
            }), 400

        # Make sure the simulation exists
        manager = SimulationManager()
        state = manager.get_simulation(simulation_id)

        if not state:
            return jsonify({
                "success": False,
                "error": t('api.simulationNotFound', id=simulation_id)
            }), 404

        force_restarted = False

        # A non-READY simulation may still be (re)started once preparation has finished.
        if state.status != SimulationStatus.READY:
            is_prepared, prepare_info = _check_simulation_prepared(simulation_id)

            if not is_prepared:
                # Preparation has not finished yet
                return jsonify({
                    "success": False,
                    "error": t('api.simNotReady', status=state.status.value)
                }), 400

            # Preparation is done; check whether a simulation process is still alive
            if state.status == SimulationStatus.RUNNING:
                run_state = SimulationRunner.get_run_state(simulation_id)
                if run_state and run_state.runner_status.value == "running":
                    # The process really is running
                    if not force:
                        return jsonify({
                            "success": False,
                            "error": t('api.simRunningForceHint')
                        }), 400

                    # Force mode: stop the running simulation first
                    logger.info(f"强制模式:停止运行中的模拟 {simulation_id}")
                    try:
                        SimulationRunner.stop_simulation(simulation_id)
                    except Exception as e:
                        logger.warning(f"停止模拟时出现警告: {str(e)}")

            # Force mode: clean the previous run's logs
            if force:
                logger.info(f"强制模式:清理模拟日志 {simulation_id}")
                cleanup_result = SimulationRunner.cleanup_simulation_logs(simulation_id)
                if not cleanup_result.get("success"):
                    logger.warning(f"清理日志时出现警告: {cleanup_result.get('errors')}")
                force_restarted = True

            # No live process (or it was just stopped): reset the status to ready
            logger.info(f"模拟 {simulation_id} 准备工作已完成,重置状态为 ready(原状态: {state.status.value})")
            state.status = SimulationStatus.READY
            manager._save_simulation_state(state)

        # Resolve the graph ID (only needed for graph memory updates)
        graph_id = None
        if enable_graph_memory_update:
            # Prefer the simulation state, fall back to the owning project
            graph_id = state.graph_id
            if not graph_id:
                project = ProjectManager.get_project(state.project_id)
                if project:
                    graph_id = project.graph_id

            if not graph_id:
                return jsonify({
                    "success": False,
                    "error": t('api.graphIdRequiredForMemory')
                }), 400

            logger.info(f"启用图谱记忆更新: simulation_id={simulation_id}, graph_id={graph_id}")

        # Launch the simulation
        run_state = SimulationRunner.start_simulation(
            simulation_id=simulation_id,
            platform=platform,
            max_rounds=max_rounds,
            enable_graph_memory_update=enable_graph_memory_update,
            graph_id=graph_id
        )

        # Persist the new status
        state.status = SimulationStatus.RUNNING
        manager._save_simulation_state(state)

        response_data = run_state.to_dict()
        if max_rounds is not None:
            response_data['max_rounds_applied'] = max_rounds
        response_data['graph_memory_update_enabled'] = enable_graph_memory_update
        response_data['force_restarted'] = force_restarted
        if enable_graph_memory_update:
            response_data['graph_id'] = graph_id

        return jsonify({
            "success": True,
            "data": response_data
        })

    except ValueError as e:
        # A ValueError from the calls above is reported as a client error.
        return jsonify({
            "success": False,
            "error": str(e)
        }), 400

    except Exception as e:
        logger.error(f"启动模拟失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
@simulation_bp.route('/stop', methods=['POST'])
def stop_simulation():
    """
    Stop a simulation and mark its stored state as PAUSED.

    Request (JSON):
        {
            "simulation_id": "sim_xxxx"  // required, simulation ID
        }

    Returns:
        {
            "success": true,
            "data": {
                "simulation_id": "sim_xxxx",
                "runner_status": "stopped",
                "completed_at": "2025-12-01T12:00:00"
            }
        }

    Responds 400 when ``simulation_id`` is missing (including a missing or
    malformed JSON body) or a ValueError is raised, and 500 otherwise.
    """
    try:
        # silent=True: a non-JSON or unparsable body yields None instead of raising a
        # Werkzeug HTTPException that the generic handler below would report as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}

        simulation_id = data.get('simulation_id')
        if not simulation_id:
            return jsonify({
                "success": False,
                "error": t('api.requireSimulationId')
            }), 400

        run_state = SimulationRunner.stop_simulation(simulation_id)

        # Record the stopped simulation as paused
        manager = SimulationManager()
        state = manager.get_simulation(simulation_id)
        if state:
            state.status = SimulationStatus.PAUSED
            manager._save_simulation_state(state)

        return jsonify({
            "success": True,
            "data": run_state.to_dict()
        })

    except ValueError as e:
        return jsonify({
            "success": False,
            "error": str(e)
        }), 400

    except Exception as e:
        logger.error(f"停止模拟失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
# ============== 实时状态监控接口 ==============
|
||
|
||
@simulation_bp.route('/<simulation_id>/run-status', methods=['GET'])
def get_run_status(simulation_id: str):
    """
    Return the live run status of a simulation (polled by the frontend).

    Returns, for example:
        {
            "success": true,
            "data": {
                "simulation_id": "sim_xxxx",
                "runner_status": "running",
                "current_round": 5,
                "total_rounds": 144,
                "progress_percent": 3.5,
                "simulated_hours": 2,
                "total_simulation_hours": 72,
                "twitter_running": true,
                "reddit_running": true,
                "twitter_actions_count": 150,
                "reddit_actions_count": 200,
                "total_actions_count": 350,
                "started_at": "2025-12-01T10:00:00",
                "updated_at": "2025-12-01T10:30:00"
            }
        }

    When no run state exists, an "idle" placeholder with zeroed counters is returned.
    """
    try:
        run_state = SimulationRunner.get_run_state(simulation_id)

        if run_state:
            payload = run_state.to_dict()
        else:
            # Nothing has been run yet: report an idle placeholder
            payload = {
                "simulation_id": simulation_id,
                "runner_status": "idle",
                "current_round": 0,
                "total_rounds": 0,
                "progress_percent": 0,
                "twitter_actions_count": 0,
                "reddit_actions_count": 0,
                "total_actions_count": 0,
            }

        return jsonify({
            "success": True,
            "data": payload
        })

    except Exception as exc:
        logger.error(f"获取运行状态失败: {str(exc)}")
        return jsonify({
            "success": False,
            "error": str(exc),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
@simulation_bp.route('/<simulation_id>/run-status/detail', methods=['GET'])
def get_run_status_detail(simulation_id: str):
    """
    Return the detailed run status of a simulation, including its actions.

    Used by the frontend to show live activity.

    Query params:
        platform: optional platform filter (twitter/reddit)

    ``data`` is the run state plus:
        all_actions:     every action (honouring the platform filter)
        twitter_actions: Twitter actions (empty when filtering for another platform)
        reddit_actions:  Reddit actions (empty when filtering for another platform)
        rounds_count:    number of entries in the run state's rounds
        recent_actions:  actions of the current round only (empty before round 1)

    Each action looks like:
        {"round_num": 5, "timestamp": "...", "platform": "twitter", "agent_id": 3,
         "agent_name": "Agent Name", "action_type": "CREATE_POST",
         "action_args": {"content": "..."}, "result": null, "success": true}
    """
    try:
        run_state = SimulationRunner.get_run_state(simulation_id)
        platform_filter = request.args.get('platform')

        if not run_state:
            return jsonify({
                "success": True,
                "data": {
                    "simulation_id": simulation_id,
                    "runner_status": "idle",
                    "all_actions": [],
                    "twitter_actions": [],
                    "reddit_actions": []
                }
            })

        def fetch(**filters):
            return SimulationRunner.get_all_actions(simulation_id=simulation_id, **filters)

        def wanted(name):
            # A platform list is only fetched when no filter is set or it matches
            return not platform_filter or platform_filter == name

        # Full action list (honouring the filter), then the per-platform lists
        all_actions = fetch(platform=platform_filter)
        twitter_actions = fetch(platform="twitter") if wanted("twitter") else []
        reddit_actions = fetch(platform="reddit") if wanted("reddit") else []

        # recent_actions shows only the latest round
        current_round = run_state.current_round
        if current_round > 0:
            recent_actions = fetch(platform=platform_filter, round_num=current_round)
        else:
            recent_actions = []

        result = run_state.to_dict()
        result.update({
            "all_actions": [action.to_dict() for action in all_actions],
            "twitter_actions": [action.to_dict() for action in twitter_actions],
            "reddit_actions": [action.to_dict() for action in reddit_actions],
            "rounds_count": len(run_state.rounds),
            "recent_actions": [action.to_dict() for action in recent_actions],
        })

        return jsonify({
            "success": True,
            "data": result
        })

    except Exception as exc:
        logger.error(f"获取详细状态失败: {str(exc)}")
        return jsonify({
            "success": False,
            "error": str(exc),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
@simulation_bp.route('/<simulation_id>/actions', methods=['GET'])
def get_simulation_actions(simulation_id: str):
    """
    Return the agent action history of a simulation.

    Query params:
        limit:     number of actions to return (default 100)
        offset:    number of actions to skip (default 0)
        platform:  platform filter (twitter/reddit)
        agent_id:  agent ID filter
        round_num: round filter

    Returns:
        {"success": true, "data": {"count": 100, "actions": [...]}}
    """
    try:
        args = request.args
        query = {
            "limit": args.get('limit', 100, type=int),
            "offset": args.get('offset', 0, type=int),
            "platform": args.get('platform'),
            "agent_id": args.get('agent_id', type=int),
            "round_num": args.get('round_num', type=int),
        }

        actions = SimulationRunner.get_actions(simulation_id=simulation_id, **query)
        action_count = len(actions)

        return jsonify({
            "success": True,
            "data": {
                "count": action_count,
                "actions": [action.to_dict() for action in actions]
            }
        })

    except Exception as exc:
        logger.error(f"获取动作历史失败: {str(exc)}")
        return jsonify({
            "success": False,
            "error": str(exc),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
@simulation_bp.route('/<simulation_id>/timeline', methods=['GET'])
def get_simulation_timeline(simulation_id: str):
    """
    Return the simulation timeline aggregated per round.

    Used by the frontend for the progress bar and timeline view.

    Query params:
        start_round: start round (default 0)
        end_round:   end round (default: all rounds)

    ``data.timeline`` holds one summary entry per round.
    """
    try:
        round_range = {
            "start_round": request.args.get('start_round', 0, type=int),
            "end_round": request.args.get('end_round', type=int),
        }

        timeline = SimulationRunner.get_timeline(simulation_id=simulation_id, **round_range)
        payload = {
            "rounds_count": len(timeline),
            "timeline": timeline
        }

        return jsonify({
            "success": True,
            "data": payload
        })

    except Exception as exc:
        logger.error(f"获取时间线失败: {str(exc)}")
        return jsonify({
            "success": False,
            "error": str(exc),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
@simulation_bp.route('/<simulation_id>/agent-stats', methods=['GET'])
def get_agent_stats(simulation_id: str):
    """
    Return per-agent statistics for a simulation.

    Used by the frontend for agent activity rankings, action distributions, etc.
    """
    try:
        stats = SimulationRunner.get_agent_stats(simulation_id)
        payload = {
            "agents_count": len(stats),
            "stats": stats
        }
        return jsonify({
            "success": True,
            "data": payload
        })

    except Exception as exc:
        logger.error(f"获取Agent统计失败: {str(exc)}")
        return jsonify({
            "success": False,
            "error": str(exc),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
# ============== 数据库查询接口 ==============
|
||
|
||
@simulation_bp.route('/<simulation_id>/posts', methods=['GET'])
def get_simulation_posts(simulation_id: str):
    """
    Return posts from a simulation, read from the platform's SQLite database.

    Query params:
        platform: "twitter" or "reddit" (default "reddit")
        limit:    number of posts to return (default 50)
        offset:   number of posts to skip (default 0)

    Posts are ordered by ``created_at`` descending, and ``data.total`` is the row
    count of the ``post`` table. If the database file does not exist, or
    ``platform`` is not a supported value, an empty result with a "database does
    not exist" message is returned. An ``sqlite3.OperationalError`` while
    querying also yields an empty result.
    """
    try:
        platform = request.args.get('platform', 'reddit')
        limit = request.args.get('limit', 50, type=int)
        offset = request.args.get('offset', 0, type=int)

        # NOTE(review): this resolves the directory relative to this module, while
        # other endpoints use Config.OASIS_SIMULATION_DATA_DIR — confirm they match.
        sim_dir = os.path.join(
            os.path.dirname(__file__),
            f'../../uploads/simulations/{simulation_id}'
        )

        # `platform` comes straight from the query string and is interpolated into a
        # file name; only accept the known databases so values such as "../../x"
        # cannot point the query at arbitrary files.
        db_path = None
        if platform in ("twitter", "reddit"):
            db_path = os.path.join(sim_dir, f"{platform}_simulation.db")

        if db_path is None or not os.path.exists(db_path):
            return jsonify({
                "success": True,
                "data": {
                    "platform": platform,
                    "count": 0,
                    "posts": [],
                    "message": t('api.dbNotExist')
                }
            })

        import sqlite3

        conn = sqlite3.connect(db_path)
        # Close the connection on every path, including unexpected errors.
        try:
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()
            try:
                cursor.execute(
                    """
                    SELECT * FROM post
                    ORDER BY created_at DESC
                    LIMIT ? OFFSET ?
                    """,
                    (limit, offset)
                )
                posts = [dict(row) for row in cursor.fetchall()]

                cursor.execute("SELECT COUNT(*) FROM post")
                total = cursor.fetchone()[0]
            except sqlite3.OperationalError:
                # e.g. the post table does not exist yet; treated as "no posts"
                posts = []
                total = 0
        finally:
            conn.close()

        return jsonify({
            "success": True,
            "data": {
                "platform": platform,
                "total": total,
                "count": len(posts),
                "posts": posts
            }
        })

    except Exception as e:
        logger.error(f"获取帖子失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
@simulation_bp.route('/<simulation_id>/comments', methods=['GET'])
def get_simulation_comments(simulation_id: str):
    """
    Return comments from a simulation (Reddit only), read from its SQLite database.

    Query params:
        post_id: only return comments of this post (optional)
        limit:   number of comments to return (default 50)
        offset:  number of comments to skip (default 0)

    Comments are ordered by ``created_at`` descending. A missing database file,
    or an ``sqlite3.OperationalError`` while querying, yields an empty list.
    """
    try:
        post_id = request.args.get('post_id')
        limit = request.args.get('limit', 50, type=int)
        offset = request.args.get('offset', 0, type=int)

        # NOTE(review): this resolves the directory relative to this module, while
        # other endpoints use Config.OASIS_SIMULATION_DATA_DIR — confirm they match.
        sim_dir = os.path.join(
            os.path.dirname(__file__),
            f'../../uploads/simulations/{simulation_id}'
        )
        db_path = os.path.join(sim_dir, "reddit_simulation.db")

        if not os.path.exists(db_path):
            return jsonify({
                "success": True,
                "data": {
                    "count": 0,
                    "comments": []
                }
            })

        import sqlite3

        conn = sqlite3.connect(db_path)
        # Close the connection on every path, including unexpected errors.
        try:
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()
            try:
                if post_id:
                    cursor.execute(
                        """
                        SELECT * FROM comment
                        WHERE post_id = ?
                        ORDER BY created_at DESC
                        LIMIT ? OFFSET ?
                        """,
                        (post_id, limit, offset)
                    )
                else:
                    cursor.execute(
                        """
                        SELECT * FROM comment
                        ORDER BY created_at DESC
                        LIMIT ? OFFSET ?
                        """,
                        (limit, offset)
                    )
                comments = [dict(row) for row in cursor.fetchall()]
            except sqlite3.OperationalError:
                # e.g. the comment table does not exist yet; treated as "no comments"
                comments = []
        finally:
            conn.close()

        return jsonify({
            "success": True,
            "data": {
                "count": len(comments),
                "comments": comments
            }
        })

    except Exception as e:
        logger.error(f"获取评论失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
# ============== Interview 采访接口 ==============
|
||
|
||
@simulation_bp.route('/interview', methods=['POST'])
def interview_agent():
    """
    Interview a single agent.

    Note: the simulation environment must be running (after the simulation loop
    finishes it enters a wait-for-commands mode).

    Request (JSON):
        {
            "simulation_id": "sim_xxxx",                 // required, simulation ID
            "agent_id": 0,                               // required, agent ID
            "prompt": "What do you think about this?",   // required, interview question
            "platform": "twitter",                       // optional, twitter/reddit;
                                                         // if omitted, a dual-platform simulation
                                                         // interviews the agent on both platforms
            "timeout": 60                                // optional, timeout in seconds, default 60
        }

    Returns (no platform, dual-platform mode):
        {
            "success": true,
            "data": {
                "agent_id": 0,
                "prompt": "...",
                "result": {
                    "agent_id": 0,
                    "prompt": "...",
                    "platforms": {
                        "twitter": {"agent_id": 0, "response": "...", "platform": "twitter"},
                        "reddit": {"agent_id": 0, "response": "...", "platform": "reddit"}
                    }
                },
                "timestamp": "2025-12-08T10:00:01"
            }
        }

    Returns (platform given):
        {
            "success": true,
            "data": {
                "agent_id": 0,
                "prompt": "...",
                "result": {
                    "agent_id": 0,
                    "response": "...",
                    "platform": "twitter",
                    "timestamp": "2025-12-08T10:00:00"
                },
                "timestamp": "2025-12-08T10:00:01"
            }
        }

    Errors: 400 for missing/invalid input (including a missing or malformed JSON
    body), a stopped environment, or a ValueError; 504 on TimeoutError; 500 otherwise.
    """
    try:
        # silent=True: a non-JSON or unparsable body yields None instead of raising a
        # Werkzeug HTTPException that the generic handler below would report as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}

        simulation_id = data.get('simulation_id')
        agent_id = data.get('agent_id')
        prompt = data.get('prompt')
        platform = data.get('platform')  # optional: twitter/reddit/None
        timeout = data.get('timeout', 60)

        if not simulation_id:
            return jsonify({
                "success": False,
                "error": t('api.requireSimulationId')
            }), 400

        if agent_id is None:
            return jsonify({
                "success": False,
                "error": t('api.requireAgentId')
            }), 400

        if not prompt:
            return jsonify({
                "success": False,
                "error": t('api.requirePrompt')
            }), 400

        # Validate the platform parameter
        if platform and platform not in ("twitter", "reddit"):
            return jsonify({
                "success": False,
                "error": t('api.invalidInterviewPlatform')
            }), 400

        # The environment must be alive to answer interviews
        if not SimulationRunner.check_env_alive(simulation_id):
            return jsonify({
                "success": False,
                "error": t('api.envNotRunning')
            }), 400

        # Prefix the prompt so the agent answers in plain text instead of calling tools
        optimized_prompt = optimize_interview_prompt(prompt)

        result = SimulationRunner.interview_agent(
            simulation_id=simulation_id,
            agent_id=agent_id,
            prompt=optimized_prompt,
            platform=platform,
            timeout=timeout
        )

        return jsonify({
            "success": result.get("success", False),
            "data": result
        })

    except ValueError as e:
        return jsonify({
            "success": False,
            "error": str(e)
        }), 400

    except TimeoutError as e:
        return jsonify({
            "success": False,
            "error": t('api.interviewTimeout', error=str(e))
        }), 504

    except Exception as e:
        logger.error(f"Interview失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
@simulation_bp.route('/interview/batch', methods=['POST'])
def interview_agents_batch():
    """
    Interview several agents in one request.

    Note: the simulation environment must be running.

    Request (JSON):
        {
            "simulation_id": "sim_xxxx",               // required, simulation ID
            "interviews": [                            // required, list of interviews
                {
                    "agent_id": 0,
                    "prompt": "What do you think of A?",
                    "platform": "twitter"              // optional, platform for this agent
                },
                {
                    "agent_id": 1,
                    "prompt": "What do you think of B?" // no platform -> use the default
                }
            ],
            "platform": "reddit",                      // optional default platform (overridden per item);
                                                       // if omitted, a dual-platform simulation interviews
                                                       // each agent on both platforms
            "timeout": 120                             // optional, timeout in seconds, default 120
        }

    Each item must be a JSON object with a non-null ``agent_id`` and a non-empty
    ``prompt`` (the same rules as the single ``/interview`` endpoint). Its
    ``platform``, if given, must be "twitter" or "reddit".

    Returns:
        {
            "success": true,
            "data": {
                "interviews_count": 2,
                "result": {
                    "interviews_count": 4,
                    "results": {
                        "twitter_0": {"agent_id": 0, "response": "...", "platform": "twitter"},
                        "reddit_0": {"agent_id": 0, "response": "...", "platform": "reddit"},
                        "twitter_1": {"agent_id": 1, "response": "...", "platform": "twitter"},
                        "reddit_1": {"agent_id": 1, "response": "...", "platform": "reddit"}
                    }
                },
                "timestamp": "2025-12-08T10:00:01"
            }
        }

    Errors: 400 for missing/invalid input (including a missing or malformed JSON
    body), a stopped environment, or a ValueError; 504 on TimeoutError; 500 otherwise.
    """
    try:
        # silent=True: a non-JSON or unparsable body yields None instead of raising a
        # Werkzeug HTTPException that the generic handler below would report as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}

        simulation_id = data.get('simulation_id')
        interviews = data.get('interviews')
        platform = data.get('platform')  # optional: twitter/reddit/None
        timeout = data.get('timeout', 120)

        if not simulation_id:
            return jsonify({
                "success": False,
                "error": t('api.requireSimulationId')
            }), 400

        if not interviews or not isinstance(interviews, list):
            return jsonify({
                "success": False,
                "error": t('api.requireInterviews')
            }), 400

        # Validate the default platform parameter
        if platform and platform not in ("twitter", "reddit"):
            return jsonify({
                "success": False,
                "error": t('api.invalidInterviewPlatform')
            }), 400

        # Validate every interview item (messages use 1-based positions)
        for i, interview in enumerate(interviews):
            # Non-object items cannot carry agent_id/prompt; report them as missing
            # agent_id instead of crashing on `in` / .get() and returning a 500.
            if not isinstance(interview, dict) or interview.get('agent_id') is None:
                return jsonify({
                    "success": False,
                    "error": t('api.interviewListMissingAgentId', index=i+1)
                }), 400
            if not interview.get('prompt'):
                return jsonify({
                    "success": False,
                    "error": t('api.interviewListMissingPrompt', index=i+1)
                }), 400
            # Validate the per-item platform (if any)
            item_platform = interview.get('platform')
            if item_platform and item_platform not in ("twitter", "reddit"):
                return jsonify({
                    "success": False,
                    "error": t('api.interviewListInvalidPlatform', index=i+1)
                }), 400

        # The environment must be alive to answer interviews
        if not SimulationRunner.check_env_alive(simulation_id):
            return jsonify({
                "success": False,
                "error": t('api.envNotRunning')
            }), 400

        # Prefix every prompt so agents answer in plain text instead of calling tools.
        # Shallow copies keep the caller's request data untouched.
        optimized_interviews = [
            {**interview, 'prompt': optimize_interview_prompt(interview['prompt'])}
            for interview in interviews
        ]

        result = SimulationRunner.interview_agents_batch(
            simulation_id=simulation_id,
            interviews=optimized_interviews,
            platform=platform,
            timeout=timeout
        )

        return jsonify({
            "success": result.get("success", False),
            "data": result
        })

    except ValueError as e:
        return jsonify({
            "success": False,
            "error": str(e)
        }), 400

    except TimeoutError as e:
        return jsonify({
            "success": False,
            "error": t('api.batchInterviewTimeout', error=str(e))
        }), 504

    except Exception as e:
        logger.error(f"批量Interview失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
@simulation_bp.route('/interview/all', methods=['POST'])
def interview_all_agents():
    """
    Global interview: ask every agent in a simulation the same question.

    The simulation environment must be running (checked with
    ``SimulationRunner.check_env_alive``); otherwise the request is rejected.

    Request (JSON):
        {
            "simulation_id": "sim_xxxx",  // required, simulation ID
            "prompt": "What do you think of this overall?",
                                          // required, non-empty string; the same
                                          // question is sent to every agent
            "platform": "reddit",         // optional, "twitter" or "reddit";
                                          // when omitted, a dual-platform
                                          // simulation interviews each agent
                                          // on both platforms
            "timeout": 180                // optional, positive number of seconds,
                                          // default 180 (null also means default)
        }

    Returns:
        ``{"success": <runner success flag>, "data": <runner result>}`` where
        ``data`` is the dict returned by ``SimulationRunner.interview_all_agents``,
        passed through unchanged. Illustrative shape (NOTE(review): confirm
        against the runner):
            {
                "interviews_count": 50,
                "result": {
                    "interviews_count": 100,
                    "results": {
                        "twitter_0": {"agent_id": 0, "response": "...", "platform": "twitter"},
                        "reddit_0": {"agent_id": 0, "response": "...", "platform": "reddit"},
                        ...
                    }
                },
                "timestamp": "2025-12-08T10:00:01"
            }

    Error responses:
        400 - missing/invalid simulation_id, prompt, platform or timeout, the
              environment is not running, or the runner raised ValueError.
        504 - the runner raised TimeoutError.
        500 - any other exception (the response includes the traceback).
    """
    try:
        data = request.get_json() or {}

        simulation_id = data.get('simulation_id')
        prompt = data.get('prompt')
        platform = data.get('platform')  # optional: twitter / reddit / None
        timeout = data.get('timeout')
        if timeout is None:
            # Missing or explicit null both fall back to the documented default.
            timeout = 180

        if not simulation_id:
            return jsonify({
                "success": False,
                "error": t('api.requireSimulationId')
            }), 400

        # A non-string prompt would crash optimize_interview_prompt (it calls
        # str.startswith) and surface as a 500; a whitespace-only prompt would
        # send agents an empty question. Both are client errors.
        if not isinstance(prompt, str) or not prompt.strip():
            return jsonify({
                "success": False,
                "error": t('api.requirePrompt')
            }), 400

        # Validate the platform parameter
        if platform and platform not in ("twitter", "reddit"):
            return jsonify({
                "success": False,
                "error": t('api.invalidInterviewPlatform')
            }), 400

        # bool is a subclass of int, so it must be excluded explicitly.
        if isinstance(timeout, bool) or not isinstance(timeout, (int, float)) or timeout <= 0:
            return jsonify({
                "success": False,
                "error": f"timeout must be a positive number of seconds, got {timeout!r}"
            }), 400

        # Check that the environment can accept interview commands
        if not SimulationRunner.check_env_alive(simulation_id):
            return jsonify({
                "success": False,
                "error": t('api.envNotRunning')
            }), 400

        # Prefix the prompt so agents answer in text instead of calling tools
        optimized_prompt = optimize_interview_prompt(prompt)

        result = SimulationRunner.interview_all_agents(
            simulation_id=simulation_id,
            prompt=optimized_prompt,
            platform=platform,
            timeout=timeout
        )

        return jsonify({
            "success": result.get("success", False),
            "data": result
        })

    except ValueError as e:
        return jsonify({
            "success": False,
            "error": str(e)
        }), 400

    except TimeoutError as e:
        return jsonify({
            "success": False,
            "error": t('api.globalInterviewTimeout', error=str(e))
        }), 504

    except Exception as e:
        logger.error(f"全局Interview失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
@simulation_bp.route('/interview/history', methods=['POST'])
def get_interview_history():
    """
    Fetch interview history records for a simulation.

    Delegates to ``SimulationRunner.get_interview_history`` (per the original
    docs, it reads the interview records stored in the simulation database).

    Request (JSON):
        {
            "simulation_id": "sim_xxxx",  // required, simulation ID
            "platform": "reddit",         // optional, "reddit" or "twitter";
                                          // when omitted, history for both
                                          // platforms is returned
            "agent_id": 0,                // optional, only this agent's history
            "limit": 100                  // optional, max records, default 100
        }

    Returns:
        {
            "success": true,
            "data": {
                "count": 10,
                "history": [
                    {
                        "agent_id": 0,
                        "response": "I think...",
                        "prompt": "What do you think about this?",
                        "timestamp": "2025-12-08T10:00:00",
                        "platform": "reddit"
                    },
                    ...
                ]
            }
        }

    Error responses:
        400 - missing simulation_id, or platform is not "twitter"/"reddit"
              (the same check the other interview endpoints apply).
        500 - any exception from the runner (the response includes the traceback).
    """
    try:
        data = request.get_json() or {}

        simulation_id = data.get('simulation_id')
        platform = data.get('platform')  # when omitted, both platforms are returned
        agent_id = data.get('agent_id')
        limit = data.get('limit', 100)

        if not simulation_id:
            return jsonify({
                "success": False,
                "error": t('api.requireSimulationId')
            }), 400

        # Reject unknown platforms up front, consistent with the interview
        # endpoints, instead of forwarding arbitrary values to the runner.
        if platform and platform not in ("twitter", "reddit"):
            return jsonify({
                "success": False,
                "error": t('api.invalidInterviewPlatform')
            }), 400

        history = SimulationRunner.get_interview_history(
            simulation_id=simulation_id,
            platform=platform,
            agent_id=agent_id,
            limit=limit
        )

        return jsonify({
            "success": True,
            "data": {
                "count": len(history),
                "history": history
            }
        })

    except Exception as e:
        logger.error(f"获取Interview历史失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
@simulation_bp.route('/env-status', methods=['POST'])
def get_env_status():
    """
    Report whether a simulation environment is alive.

    "Alive" means the environment can currently accept interview commands,
    as decided by ``SimulationRunner.check_env_alive``. Per-platform
    availability comes from ``SimulationRunner.get_env_status_detail`` and
    defaults to False when the detail dict lacks the key.

    Request (JSON):
        {
            "simulation_id": "sim_xxxx"  // required, simulation ID
        }

    Returns:
        {
            "success": true,
            "data": {
                "simulation_id": "sim_xxxx",
                "env_alive": true,
                "twitter_available": true,
                "reddit_available": true,
                "message": "<localized running / not-running message>"
            }
        }

    Error responses:
        400 - missing simulation_id.
        500 - any exception (the response includes the traceback).
    """
    try:
        payload = request.get_json() or {}
        sim_id = payload.get('simulation_id')

        if not sim_id:
            return jsonify({
                "success": False,
                "error": t('api.requireSimulationId')
            }), 400

        alive = SimulationRunner.check_env_alive(sim_id)
        # Separate call for the per-platform breakdown.
        detail = SimulationRunner.get_env_status_detail(sim_id)

        status_message = t('api.envRunning') if alive else t('api.envNotRunningShort')

        response_data = {
            "simulation_id": sim_id,
            "env_alive": alive,
            "twitter_available": detail.get("twitter_available", False),
            "reddit_available": detail.get("reddit_available", False),
            "message": status_message,
        }
        return jsonify({"success": True, "data": response_data})

    except Exception as exc:
        logger.error(f"获取环境状态失败: {str(exc)}")
        return jsonify({
            "success": False,
            "error": str(exc),
            "traceback": traceback.format_exc()
        }), 500
|
||
|
||
|
||
@simulation_bp.route('/close-env', methods=['POST'])
def close_simulation_env():
    """
    Gracefully close a simulation environment.

    Sends a close-environment command to the simulation so it can leave its
    command-waiting mode and exit on its own. Per the original docs, this
    differs from /stop, which forcibly terminates the process.

    The stored simulation status is set to COMPLETED only when the runner
    reports success. If the close command fails, the previous status is kept,
    so a simulation whose environment may still be running is not recorded
    as finished.

    Request (JSON):
        {
            "simulation_id": "sim_xxxx",  // required, simulation ID
            "timeout": 30                 // optional, timeout in seconds, default 30
        }

    Returns:
        ``{"success": <runner success flag>, "data": <runner result>}`` where
        ``data`` is the dict returned by ``SimulationRunner.close_simulation_env``.
        Illustrative shape (NOTE(review): confirm against the runner):
            {
                "message": "...",
                "result": {...},
                "timestamp": "2025-12-08T10:00:01"
            }

    Error responses:
        400 - missing simulation_id, or the runner raised ValueError.
        500 - any other exception (the response includes the traceback).
    """
    try:
        data = request.get_json() or {}

        simulation_id = data.get('simulation_id')
        timeout = data.get('timeout', 30)

        if not simulation_id:
            return jsonify({
                "success": False,
                "error": t('api.requireSimulationId')
            }), 400

        result = SimulationRunner.close_simulation_env(
            simulation_id=simulation_id,
            timeout=timeout
        )

        if result.get("success", False):
            # Only record completion once the environment actually closed.
            manager = SimulationManager()
            state = manager.get_simulation(simulation_id)
            if state:
                state.status = SimulationStatus.COMPLETED
                manager._save_simulation_state(state)
        else:
            logger.warning(
                f"Close-env command for {simulation_id} did not succeed; "
                f"simulation status left unchanged"
            )

        return jsonify({
            "success": result.get("success", False),
            "data": result
        })

    except ValueError as e:
        return jsonify({
            "success": False,
            "error": str(e)
        }), 400

    except Exception as e:
        logger.error(f"关闭环境失败: {str(e)}")
        return jsonify({
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }), 500
|