mirror of
https://github.com/vegu-ai/talemate.git
synced 2025-09-02 18:39:09 +00:00
disable revision during image prompt generation
This commit is contained in:
parent
5e9e89d452
commit
f7c61aff02
2 changed files with 43 additions and 22 deletions
|
@ -28,8 +28,9 @@ from talemate.util import count_tokens
|
||||||
from talemate.prompts import Prompt
|
from talemate.prompts import Prompt
|
||||||
from talemate.exceptions import GenerationCancelled
|
from talemate.exceptions import GenerationCancelled
|
||||||
import talemate.game.focal as focal
|
import talemate.game.focal as focal
|
||||||
from talemate.status import LoadingStatus, set_loading
|
from talemate.status import LoadingStatus
|
||||||
from talemate.world_state.templates.content import PhraseDetection
|
from talemate.world_state.templates.content import PhraseDetection
|
||||||
|
from contextvars import ContextVar
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from talemate.tale_mate import Character, Scene
|
from talemate.tale_mate import Character, Scene
|
||||||
|
@ -51,6 +52,16 @@ detect_bad_prose_condition = AgentActionConditional(
|
||||||
value=True,
|
value=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
revision_disabled_context = ContextVar("revision_disabled", default=False)
|
||||||
|
|
||||||
|
class RevisionDisabled:
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.token = revision_disabled_context.set(True)
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
revision_disabled_context.reset(self.token)
|
||||||
|
|
||||||
class RevisionMixin:
|
class RevisionMixin:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@ -226,6 +237,13 @@ class RevisionMixin:
|
||||||
if not self.revision_enabled:
|
if not self.revision_enabled:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
if revision_disabled_context.get():
|
||||||
|
log.debug("revision_on_generation: revision disabled through context", emission=emission)
|
||||||
|
return
|
||||||
|
except LookupError:
|
||||||
|
pass
|
||||||
|
|
||||||
log.info("revise generation", emission=emission)
|
log.info("revise generation", emission=emission)
|
||||||
|
|
||||||
edited = []
|
edited = []
|
||||||
|
|
|
@ -15,6 +15,7 @@ from talemate.agents.base import (
|
||||||
set_processing,
|
set_processing,
|
||||||
)
|
)
|
||||||
from talemate.agents.registry import register
|
from talemate.agents.registry import register
|
||||||
|
from talemate.agents.editor.revision import RevisionDisabled
|
||||||
from talemate.client.base import ClientBase
|
from talemate.client.base import ClientBase
|
||||||
from talemate.config import load_config
|
from talemate.config import load_config
|
||||||
from talemate.emit import emit
|
from talemate.emit import emit
|
||||||
|
@ -504,6 +505,7 @@ class VisualBase(Agent):
|
||||||
@set_processing
|
@set_processing
|
||||||
async def generate_environment_prompt(self, instructions: str = None):
|
async def generate_environment_prompt(self, instructions: str = None):
|
||||||
|
|
||||||
|
with RevisionDisabled():
|
||||||
response = await Prompt.request(
|
response = await Prompt.request(
|
||||||
"visual.generate-environment-prompt",
|
"visual.generate-environment-prompt",
|
||||||
self.client,
|
self.client,
|
||||||
|
@ -523,6 +525,7 @@ class VisualBase(Agent):
|
||||||
|
|
||||||
character = self.scene.get_character(character_name)
|
character = self.scene.get_character(character_name)
|
||||||
|
|
||||||
|
with RevisionDisabled():
|
||||||
response = await Prompt.request(
|
response = await Prompt.request(
|
||||||
"visual.generate-character-prompt",
|
"visual.generate-character-prompt",
|
||||||
self.client,
|
self.client,
|
||||||
|
|
Loading…
Add table
Reference in a new issue