mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-05 23:42:21 +00:00
add clone public chat service logic
This commit is contained in:
parent
37526b74a9
commit
1ab084aa31
1 changed files with 215 additions and 0 deletions
|
|
@ -197,3 +197,218 @@ async def get_thread_by_share_token(
|
|||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
|
||||
async def get_user_default_search_space(
|
||||
session: AsyncSession,
|
||||
user_id: UUID,
|
||||
) -> int | None:
|
||||
"""
|
||||
Get user's default search space for cloning.
|
||||
|
||||
Returns the first search space where user is owner, or None if not found.
|
||||
"""
|
||||
from app.db import SearchSpaceMembership
|
||||
|
||||
result = await session.execute(
|
||||
select(SearchSpaceMembership)
|
||||
.filter(
|
||||
SearchSpaceMembership.user_id == user_id,
|
||||
SearchSpaceMembership.is_owner.is_(True),
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
membership = result.scalars().first()
|
||||
|
||||
if membership:
|
||||
return membership.search_space_id
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def clone_public_chat(
|
||||
session: AsyncSession,
|
||||
share_token: str,
|
||||
user_id: UUID,
|
||||
) -> dict:
|
||||
"""
|
||||
Clone a public chat to user's account.
|
||||
|
||||
Creates a new private thread with all messages and podcasts.
|
||||
"""
|
||||
import copy
|
||||
|
||||
from app.db import (
|
||||
ChatVisibility,
|
||||
NewChatMessage,
|
||||
)
|
||||
|
||||
source_thread = await get_thread_by_share_token(session, share_token)
|
||||
if not source_thread:
|
||||
await _create_clone_failure_notification(
|
||||
session, user_id, share_token, "Chat not found or no longer public"
|
||||
)
|
||||
return {"status": "error", "error": "Chat not found or no longer public"}
|
||||
|
||||
try:
|
||||
target_search_space_id = await get_user_default_search_space(session, user_id)
|
||||
|
||||
if target_search_space_id is None:
|
||||
await _create_clone_failure_notification(
|
||||
session, user_id, share_token, "No search space found"
|
||||
)
|
||||
return {"status": "error", "error": "No search space found"}
|
||||
|
||||
new_thread = NewChatThread(
|
||||
title=source_thread.title,
|
||||
archived=False,
|
||||
visibility=ChatVisibility.PRIVATE,
|
||||
search_space_id=target_search_space_id,
|
||||
created_by_id=user_id,
|
||||
public_share_enabled=False,
|
||||
)
|
||||
session.add(new_thread)
|
||||
await session.flush()
|
||||
|
||||
podcast_id_map: dict[int, int] = {}
|
||||
|
||||
for msg in sorted(source_thread.messages, key=lambda m: m.created_at):
|
||||
new_content = copy.deepcopy(msg.content)
|
||||
|
||||
if isinstance(new_content, list):
|
||||
for part in new_content:
|
||||
if (
|
||||
isinstance(part, dict)
|
||||
and part.get("type") == "tool-call"
|
||||
and part.get("toolName") == "generate_podcast"
|
||||
):
|
||||
result = part.get("result", {})
|
||||
old_podcast_id = result.get("podcast_id")
|
||||
if old_podcast_id and old_podcast_id not in podcast_id_map:
|
||||
new_podcast_id = await _clone_podcast(
|
||||
session,
|
||||
old_podcast_id,
|
||||
target_search_space_id,
|
||||
)
|
||||
if new_podcast_id:
|
||||
podcast_id_map[old_podcast_id] = new_podcast_id
|
||||
|
||||
if old_podcast_id and old_podcast_id in podcast_id_map:
|
||||
result["podcast_id"] = podcast_id_map[old_podcast_id]
|
||||
|
||||
new_message = NewChatMessage(
|
||||
thread_id=new_thread.id,
|
||||
role=msg.role,
|
||||
content=new_content,
|
||||
author_id=msg.author_id,
|
||||
created_at=msg.created_at,
|
||||
)
|
||||
session.add(new_message)
|
||||
|
||||
await session.commit()
|
||||
|
||||
await _create_clone_success_notification(
|
||||
session,
|
||||
user_id,
|
||||
new_thread.id,
|
||||
target_search_space_id,
|
||||
source_thread.title,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"thread_id": new_thread.id,
|
||||
"search_space_id": target_search_space_id,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
await _create_clone_failure_notification(session, user_id, share_token, str(e))
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
async def _clone_podcast(
|
||||
session: AsyncSession,
|
||||
podcast_id: int,
|
||||
target_search_space_id: int,
|
||||
) -> int | None:
|
||||
"""Clone a podcast record and its audio file."""
|
||||
import shutil
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
from app.db import Podcast
|
||||
|
||||
result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id))
|
||||
original = result.scalars().first()
|
||||
if not original:
|
||||
return None
|
||||
|
||||
new_file_path = None
|
||||
if original.file_location:
|
||||
original_path = Path(original.file_location)
|
||||
if original_path.exists():
|
||||
new_filename = f"{uuid.uuid4()}_podcast.mp3"
|
||||
new_dir = Path("podcasts")
|
||||
new_dir.mkdir(parents=True, exist_ok=True)
|
||||
new_file_path = str(new_dir / new_filename)
|
||||
shutil.copy2(original.file_location, new_file_path)
|
||||
|
||||
new_podcast = Podcast(
|
||||
title=original.title,
|
||||
podcast_transcript=original.podcast_transcript,
|
||||
file_location=new_file_path,
|
||||
search_space_id=target_search_space_id,
|
||||
)
|
||||
session.add(new_podcast)
|
||||
await session.flush()
|
||||
|
||||
return new_podcast.id
|
||||
|
||||
|
||||
async def _create_clone_success_notification(
|
||||
session: AsyncSession,
|
||||
user_id: UUID,
|
||||
thread_id: int,
|
||||
search_space_id: int,
|
||||
original_title: str,
|
||||
) -> None:
|
||||
"""Create success notification for clone operation."""
|
||||
from app.db import Notification
|
||||
|
||||
notification = Notification(
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
type="chat_cloned",
|
||||
title="Chat copied successfully",
|
||||
message=f"Your copy of '{original_title}' is ready",
|
||||
notification_metadata={
|
||||
"thread_id": thread_id,
|
||||
"search_space_id": search_space_id,
|
||||
},
|
||||
)
|
||||
session.add(notification)
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def _create_clone_failure_notification(
|
||||
session: AsyncSession,
|
||||
user_id: UUID,
|
||||
share_token: str,
|
||||
error: str,
|
||||
) -> None:
|
||||
"""Create failure notification for clone operation."""
|
||||
from app.db import Notification
|
||||
|
||||
notification = Notification(
|
||||
user_id=user_id,
|
||||
type="chat_clone_failed",
|
||||
title="Failed to copy chat",
|
||||
message="Could not copy the chat. Please try again.",
|
||||
notification_metadata={
|
||||
"share_token": share_token,
|
||||
"error": error,
|
||||
},
|
||||
)
|
||||
session.add(notification)
|
||||
await session.commit()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue