fix(graph): enforce PascalCase for entity names and SCREAMING_SNAKE_CASE for edge names in ontology validation

This commit is contained in:
666ghj 2026-04-01 17:42:27 +08:00
parent 7c7c7a2c63
commit e3350a919d

View file

@ -4,9 +4,26 @@
"""
import json
import logging
import re
from typing import Dict, Any, List, Optional
from ..utils.llm_client import LLMClient
logger = logging.getLogger(__name__)
def _to_pascal_case(name: str) -> str:
"""将任意格式的名称转换为 PascalCase'works_for' -> 'WorksFor', 'person' -> 'Person'"""
# 按非字母数字字符分割
parts = re.split(r'[^a-zA-Z0-9]+', name)
# 再按 camelCase 边界分割(如 'camelCase' -> ['camel', 'Case']
words = []
for part in parts:
words.extend(re.sub(r'([a-z])([A-Z])', r'\1_\2', part).split('_'))
# 每个词首字母大写,过滤空串
result = ''.join(word.capitalize() for word in words if word)
return result if result else 'Unknown'
# 本体生成的系统提示词
ONTOLOGY_SYSTEM_PROMPT = """你是一个专业的知识图谱本体设计专家。你的任务是分析给定的文本内容和模拟需求,设计适合**社交媒体舆论模拟**的实体类型和关系类型。
@ -266,7 +283,16 @@ class OntologyGenerator:
result["analysis_summary"] = ""
# 验证实体类型
# 记录原始名称到 PascalCase 的映射,用于后续修正 edge 的 source_targets 引用
entity_name_map = {}
for entity in result["entity_types"]:
# 强制将 entity name 转为 PascalCaseZep API 要求)
if "name" in entity:
original_name = entity["name"]
entity["name"] = _to_pascal_case(original_name)
if entity["name"] != original_name:
logger.warning(f"Entity type name '{original_name}' auto-converted to '{entity['name']}'")
entity_name_map[original_name] = entity["name"]
if "attributes" not in entity:
entity["attributes"] = []
if "examples" not in entity:
@ -277,6 +303,18 @@ class OntologyGenerator:
# 验证关系类型
for edge in result["edge_types"]:
# 强制将 edge name 转为 SCREAMING_SNAKE_CASEZep API 要求)
if "name" in edge:
original_name = edge["name"]
edge["name"] = original_name.upper()
if edge["name"] != original_name:
logger.warning(f"Edge type name '{original_name}' auto-converted to '{edge['name']}'")
# 修正 source_targets 中的实体名称引用,与转换后的 PascalCase 保持一致
for st in edge.get("source_targets", []):
if st.get("source") in entity_name_map:
st["source"] = entity_name_map[st["source"]]
if st.get("target") in entity_name_map:
st["target"] = entity_name_map[st["target"]]
if "source_targets" not in edge:
edge["source_targets"] = []
if "attributes" not in edge:
@ -287,7 +325,19 @@ class OntologyGenerator:
# Zep API 限制:最多 10 个自定义实体类型,最多 10 个自定义边类型
MAX_ENTITY_TYPES = 10
MAX_EDGE_TYPES = 10
# 去重:按 name 去重,保留首次出现的
seen_names = set()
deduped = []
for entity in result["entity_types"]:
name = entity.get("name", "")
if name and name not in seen_names:
seen_names.add(name)
deduped.append(entity)
elif name in seen_names:
logger.warning(f"Duplicate entity type '{name}' removed during validation")
result["entity_types"] = deduped
# 兜底类型定义
person_fallback = {
"name": "Person",