mirror of
https://github.com/lfnovo/open-notebook.git
synced 2026-05-01 21:00:43 +00:00
refactor: reorganize folder structure for better maintainability
Changes:
- Move migrations/ under open_notebook/database/migrations/
- Extract AI models to open_notebook/ai/ (Model, ModelManager, provision)
- Extract podcasts to open_notebook/podcasts/ (EpisodeProfile, SpeakerProfile, PodcastEpisode)
- Reorganize prompts to mirror graphs structure (chat/, source_chat/)

This improves code organization by:
- Consolidating database concerns (migrations now with database code)
- Separating AI infrastructure from domain entities
- Isolating podcast feature into its own module
- Creating consistent prompt/graph naming conventions

All 52 tests pass.
This commit is contained in:
parent
93cda6c42a
commit
ab5560c9a2
48 changed files with 50 additions and 47 deletions
148
open_notebook/podcasts/models.py
Normal file
148
open_notebook/podcasts/models.py
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
from typing import Any, ClassVar, Dict, List, Optional, Union

from pydantic import ConfigDict, Field, field_validator
from surrealdb import RecordID

from open_notebook.database.repository import ensure_record_id, repo_query
from open_notebook.domain.base import ObjectModel
|
||||
|
||||
|
||||
class EpisodeProfile(ObjectModel):
    """
    Episode Profile - Simplified podcast configuration.
    Replaces complex 15+ field configuration with user-friendly profiles.
    """

    # Table that get_by_name() queries for this model.
    table_name: ClassVar[str] = "episode_profile"

    # --- Identity -------------------------------------------------------
    name: str = Field(..., description="Unique profile name")
    description: Optional[str] = Field(
        default=None, description="Profile description"
    )

    # --- Generation settings --------------------------------------------
    speaker_config: str = Field(
        ..., description="Reference to speaker profile name"
    )
    outline_provider: str = Field(
        ..., description="AI provider for outline generation"
    )
    outline_model: str = Field(
        ..., description="AI model for outline generation"
    )
    transcript_provider: str = Field(
        ..., description="AI provider for transcript generation"
    )
    transcript_model: str = Field(
        ..., description="AI model for transcript generation"
    )
    default_briefing: str = Field(..., description="Default briefing template")
    num_segments: int = Field(default=5, description="Number of podcast segments")

    @field_validator("num_segments")
    @classmethod
    def validate_segments(cls, v):
        """Reject segment counts outside the inclusive range 3..20.

        Returns the value unchanged when it is within range; raises
        ValueError otherwise.
        """
        if v < 3 or v > 20:
            raise ValueError("Number of segments must be between 3 and 20")
        return v

    @classmethod
    async def get_by_name(cls, name: str) -> Optional["EpisodeProfile"]:
        """Look up an episode profile by its name.

        Returns an instance built from the first row of the query result,
        or None when the query yields nothing.
        """
        query = "SELECT * FROM episode_profile WHERE name = $name"
        rows = await repo_query(query, {"name": name})
        if not rows:
            return None
        # Only the first matching row is used.
        return cls(**rows[0])
|
||||
|
||||
|
||||
class SpeakerProfile(ObjectModel):
    """
    Speaker Profile - Voice and personality configuration.
    Supports 1-4 speakers for flexible podcast formats.
    """

    # Table that get_by_name() queries for this model.
    table_name: ClassVar[str] = "speaker_profile"

    # --- Identity -------------------------------------------------------
    name: str = Field(..., description="Unique profile name")
    description: Optional[str] = Field(
        default=None, description="Profile description"
    )

    # --- Text-to-speech settings ----------------------------------------
    tts_provider: str = Field(
        ..., description="TTS provider (openai, elevenlabs, etc.)"
    )
    tts_model: str = Field(..., description="TTS model name")
    speakers: List[Dict[str, Any]] = Field(
        ..., description="Array of speaker configurations"
    )

    @field_validator("speakers")
    @classmethod
    def validate_speakers(cls, v):
        """Check the speaker list size and each speaker's required keys.

        The list must hold between 1 and 4 entries, and every entry must
        contain the keys "name", "voice_id", "backstory" and "personality".
        Returns the list unchanged on success; raises ValueError naming the
        first missing key otherwise.
        """
        speaker_count = len(v)
        if speaker_count < 1 or speaker_count > 4:
            raise ValueError("Must have between 1 and 4 speakers")

        mandatory_keys = ("name", "voice_id", "backstory", "personality")
        for entry in v:
            # First absent key in declaration order, or None if all present.
            absent = next(
                (key for key in mandatory_keys if key not in entry), None
            )
            if absent is not None:
                raise ValueError(f"Speaker missing required field: {absent}")
        return v

    @classmethod
    async def get_by_name(cls, name: str) -> Optional["SpeakerProfile"]:
        """Look up a speaker profile by its name.

        Returns an instance built from the first row of the query result,
        or None when the query yields nothing.
        """
        query = "SELECT * FROM speaker_profile WHERE name = $name"
        rows = await repo_query(query, {"name": name})
        if not rows:
            return None
        # Only the first matching row is used.
        return cls(**rows[0])
|
||||
|
||||
|
||||
class PodcastEpisode(ObjectModel):
    """Enhanced PodcastEpisode with job tracking and metadata"""

    table_name: ClassVar[str] = "episode"

    # Pydantic v2 configuration. This replaces the deprecated inner
    # `class Config`, which triggers a deprecation warning under pydantic v2.
    # Presumably required because `RecordID` (used by `command`) is not a
    # type pydantic can validate natively -- confirm against surrealdb.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    name: str = Field(..., description="Episode name")
    episode_profile: Dict[str, Any] = Field(
        ..., description="Episode profile used (stored as object)"
    )
    speaker_profile: Dict[str, Any] = Field(
        ..., description="Speaker profile used (stored as object)"
    )
    briefing: str = Field(..., description="Full briefing used for generation")
    content: str = Field(..., description="Source content")
    audio_file: Optional[str] = Field(
        default=None, description="Path to generated audio file"
    )
    transcript: Optional[Dict[str, Any]] = Field(
        default_factory=dict, description="Generated transcript"
    )
    outline: Optional[Dict[str, Any]] = Field(
        default_factory=dict, description="Generated outline"
    )
    command: Optional[Union[str, RecordID]] = Field(
        default=None, description="Link to surreal-commands job"
    )

    async def get_job_status(self) -> Optional[str]:
        """Get the status of the associated command.

        Returns:
            None when no command is linked to this episode; otherwise the
            ``status`` attribute of the object returned by
            ``get_command_status``, or "unknown" when that call returns a
            falsy value or anything raises.
        """
        if not self.command:
            return None

        # Deliberately best-effort: any failure (including a failed import
        # of surreal_commands) is reported as "unknown" rather than raised.
        try:
            # Imported lazily, inside the try, so an import failure is also
            # covered by the handler below.
            from surreal_commands import get_command_status

            status = await get_command_status(str(self.command))
            return status.status if status else "unknown"
        except Exception:
            return "unknown"

    @field_validator("command", mode="before")
    @classmethod
    def parse_command(cls, value):
        """Pass string inputs through ``ensure_record_id`` before validation.

        Non-string values (including None) are returned untouched.
        """
        if isinstance(value, str):
            return ensure_record_id(value)
        return value

    def _prepare_save_data(self) -> dict:
        """Override to ensure command field is always RecordID format for database.

        Starts from the parent's save payload and, when it carries a
        non-None "command" entry, passes that entry through
        ``ensure_record_id``.
        """
        data = super()._prepare_save_data()

        # Ensure command field is RecordID format if not None
        if data.get("command") is not None:
            data["command"] = ensure_record_id(data["command"])

        return data
|
||||
Loading…
Add table
Add a link
Reference in a new issue