mirror of
https://github.com/vegu-ai/talemate.git
synced 2025-09-02 10:29:17 +00:00
* fix issue where recent save cover images would sometimes not load * paraphrase prompt tweaks * action_to_narration regenerate compatibility fixes * sim suite add asnwer question instruction * more sim suite tweaks * refactor agent details display in agent bar * visual agent progres (a1111 support) * visual gen prompt tweaks * openai compat client pass max_tokens * world state sequential reinforcement max tokens tightened * improve item names * Improve item names * attempt to remove "changed from.." notes when altering an existing character sheet * prompt improvements for single character portraits * visual agent progress * fix issue where character.update wouldn't update long-term memory * remove experimental flag for now * add better instructions for updating existing character sheet * background processing for agents, visual and tts * fix selected voice not saving between restarts for elevenlabs * lessen timeout * clean up agent status logic * conditional agent configs * comfyui support * visualization queue * refactor visual styles, comfyui progress * regen images auto cover image assign websocket handler plugin abstraction agent websocket handler * automatic1111 fixes agent status and ready checks * tweaks to character portrait prompt * system prompt for visualize * textgenwebui use temp smoothing on yi models * comment out api key for now * fixes issues with openai compat client for retaining api key and auto fixing urls * update_reinforcment tweaks * agent status emit from one place * emit agent status as asyncio task * remove debug output * tts add openai support * openai img gen support * fix issue with confyui checkbox list not loading * tts model selection for openai * narrate_query include character sheet if character is referenced in query improve visual character portrit generation prompt * client implementation extra field support and runpod vllm client example * relock * fix issue where changing context length would cause next generation to error * 
visual agent tweaks and auto gen character cover image in sim suite * fix issue with readyness lock when there werent any clients defined * load scene readiness fixes * linting * docs * notes for the runpod vllm example
130 lines
4 KiB
Python
130 lines
4 KiB
Python
"""
An attempt to write a client against the RunPod serverless vLLM worker.

This is close to functional, but since RunPod serverless GPU availability is currently terrible, I have
been unable to properly test it.

Putting it here for now since I think it makes a decent example of how to write a client against a new service.
"""
|
|
|
|
import pydantic
|
|
import structlog
|
|
import runpod
|
|
import asyncio
|
|
import aiohttp
|
|
from talemate.client.base import ClientBase, ExtraField
|
|
from talemate.client.registry import register
|
|
from talemate.emit import emit
|
|
from talemate.config import Client as BaseClientConfig
|
|
|
|
log = structlog.get_logger("talemate.client.runpod_vllm")
|
|
|
|
class Defaults(pydantic.BaseModel):
    """Default values for the fields of a new RunPod vLLM client.

    An instance of this model is exposed through ``RunPodVLLMClient.Meta.defaults``.
    """

    # Default value for the client's ``max_token_length`` setting.
    max_token_length: int = 4096
    # Model name; empty by default.
    model: str = ""
    # RunPod id of the endpoint to talk to; empty by default.
    runpod_id: str = ""
|
|
|
|
class ClientConfig(BaseClientConfig):
    """Client configuration extended with the RunPod-specific ``runpod_id`` field."""

    # RunPod id of the endpoint this client connects to; empty by default.
    runpod_id: str = ""
|
|
|
|
@register()
class RunPodVLLMClient(ClientBase):
    """Example client that sends prompts to a RunPod serverless vLLM endpoint.

    The endpoint is addressed by ``runpod_id``. A generation is submitted as a
    job through ``runpod.AsyncioEndpoint`` and polled until it reaches a
    terminal status.
    """

    client_type = "runpod_vllm"
    conversation_retries = 5
    config_cls = ClientConfig

    class Meta(ClientBase.Meta):
        """Client metadata, including the extra ``runpod_id`` field definition."""

        title: str = "Runpod VLLM"
        name_prefix: str = "Runpod VLLM"
        enable_api_auth: bool = True
        manual_model: bool = True
        defaults: Defaults = Defaults()
        extra_fields: dict[str, ExtraField] = {
            "runpod_id": ExtraField(
                name="runpod_id",
                type="text",
                label="Runpod ID",
                required=True,
                description="The Runpod ID to connect to.",
            )
        }

    def __init__(self, model=None, runpod_id=None, **kwargs):
        """Store the model name and RunPod id, then initialize the base client.

        Args:
            model: Model name reported by ``get_model_name``.
            runpod_id: RunPod id passed to ``runpod.AsyncioEndpoint``.
            **kwargs: Forwarded unchanged to ``ClientBase.__init__``.
        """
        # Assigned before the base initializer runs, preserving the original
        # ordering (the base class may already read these attributes).
        self.model_name = model
        self.runpod_id = runpod_id
        super().__init__(**kwargs)

    @property
    def experimental(self):
        """Always ``False`` for this client."""
        return False

    def set_client(self, **kwargs):
        """Update ``runpod_id`` from ``kwargs``, keeping the current value if absent."""
        log.debug("set_client", kwargs=kwargs, runpod_id=self.runpod_id)
        self.runpod_id = kwargs.get("runpod_id", self.runpod_id)

    def tune_prompt_parameters(self, parameters: dict, kind: str):
        """Apply the base tuning, then strip every parameter not in the allow-list.

        ``parameters`` is modified in place; only ``temperature``, ``top_p``
        and ``max_tokens`` are kept.
        """
        super().tune_prompt_parameters(parameters, kind)

        valid_keys = ("temperature", "top_p", "max_tokens")

        # Iterate over a snapshot of the keys because entries are deleted
        # from the dict while looping.
        for key in list(parameters):
            if key not in valid_keys:
                del parameters[key]

    async def get_model_name(self):
        """Return the configured model name."""
        return self.model_name

    async def generate(self, prompt: str, parameters: dict, kind: str):
        """
        Generates text from the given prompt and parameters.

        Submits the prompt as a job to the RunPod endpoint, polls the job
        until it reaches a terminal status and returns the first token
        sequence of the first choice in the job output.

        Any failure (including a job that ends in a non-``COMPLETED`` status)
        is logged, reported through an ``error`` status emit, and results in
        an empty string being returned.
        """
        prompt = prompt.strip()

        self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)

        terminal_statuses = ("COMPLETED", "FAILED", "CANCELLED")

        try:
            async with aiohttp.ClientSession() as session:
                endpoint = runpod.AsyncioEndpoint(self.runpod_id, session)

                # NOTE: only the prompt is submitted; ``parameters`` are not
                # forwarded to the worker yet.
                run_request = await endpoint.run(
                    {
                        "input": {
                            "prompt": prompt,
                        }
                    }
                )

                # Query the status once per iteration; each status() call is
                # awaited separately, so avoid asking twice per loop.
                status = await run_request.status()
                while status not in terminal_statuses:
                    log.debug("generate", status=status)
                    await asyncio.sleep(0.1)
                    status = await run_request.status()

                log.debug("generate", status=status)

                if status != "COMPLETED":
                    # Fail with a meaningful message instead of an opaque
                    # KeyError/TypeError while indexing a missing output.
                    raise RuntimeError(f"runpod job ended with status {status}")

                response = await run_request.output()

                log.debug("generate", response=response)

                return response["choices"][0]["tokens"][0]

        except Exception as e:
            # Boundary handler: generation errors must not propagate to the
            # caller; they are logged and surfaced as a status message.
            self.log.error("generate error", e=e)
            emit(
                "status", message="Error during generation (check logs)", status="error"
            )
            return ""

    def reconfigure(self, **kwargs):
        """Apply updated settings to this client.

        A truthy ``model`` replaces the model name. A ``runpod_id`` entry is
        applied by ``set_client``. The API credential (``api_auth``) is left
        untouched: the RunPod id is an endpoint identifier, not a credential.
        """
        if kwargs.get("model"):
            self.model_name = kwargs["model"]
        log.warning("reconfigure", kwargs=kwargs)
        # set_client picks up "runpod_id" from kwargs when present.
        self.set_client(**kwargs)
|