diff --git a/src/talemate/agents/editor/revision.py b/src/talemate/agents/editor/revision.py index 187e78fd..1d2f42b6 100644 --- a/src/talemate/agents/editor/revision.py +++ b/src/talemate/agents/editor/revision.py @@ -28,8 +28,9 @@ from talemate.util import count_tokens from talemate.prompts import Prompt from talemate.exceptions import GenerationCancelled import talemate.game.focal as focal -from talemate.status import LoadingStatus, set_loading +from talemate.status import LoadingStatus from talemate.world_state.templates.content import PhraseDetection +from contextvars import ContextVar if TYPE_CHECKING: from talemate.tale_mate import Character, Scene @@ -51,6 +52,16 @@ detect_bad_prose_condition = AgentActionConditional( value=True, ) +revision_disabled_context = ContextVar("revision_disabled", default=False) + +class RevisionDisabled: + + def __enter__(self): + self.token = revision_disabled_context.set(True) + + def __exit__(self, exc_type, exc_value, traceback): + revision_disabled_context.reset(self.token) + class RevisionMixin: """ @@ -226,6 +237,13 @@ class RevisionMixin: if not self.revision_enabled: return + try: + if revision_disabled_context.get(): + log.debug("revision_on_generation: revision disabled through context", emission=emission) + return + except LookupError: + pass + log.info("revise generation", emission=emission) edited = [] diff --git a/src/talemate/agents/visual/__init__.py b/src/talemate/agents/visual/__init__.py index cbf58c7c..b98f8e20 100644 --- a/src/talemate/agents/visual/__init__.py +++ b/src/talemate/agents/visual/__init__.py @@ -15,6 +15,7 @@ from talemate.agents.base import ( set_processing, ) from talemate.agents.registry import register +from talemate.agents.editor.revision import RevisionDisabled from talemate.client.base import ClientBase from talemate.config import load_config from talemate.emit import emit @@ -504,15 +505,16 @@ class VisualBase(Agent): @set_processing async def generate_environment_prompt(self, instructions: str = None): - response = await Prompt.request( - "visual.generate-environment-prompt", - self.client, - "visualize", - { - "scene": self.scene, - "max_tokens": self.client.max_token_length, - }, - ) + with RevisionDisabled(): + response = await Prompt.request( + "visual.generate-environment-prompt", + self.client, + "visualize", + { + "scene": self.scene, + "max_tokens": self.client.max_token_length, + }, + ) return response.strip() @@ -523,18 +525,19 @@ class VisualBase(Agent): character = self.scene.get_character(character_name) - response = await Prompt.request( - "visual.generate-character-prompt", - self.client, - "visualize", - { - "scene": self.scene, - "character_name": character_name, - "character": character, - "max_tokens": self.client.max_token_length, - "instructions": instructions or "", - }, - ) + with RevisionDisabled(): + response = await Prompt.request( + "visual.generate-character-prompt", + self.client, + "visualize", + { + "scene": self.scene, + "character_name": character_name, + "character": character, + "max_tokens": self.client.max_token_length, + "instructions": instructions or "", + }, + ) return response.strip()