talemate/docs/dev/client/example/runpod_vllm/__init__.py
veguAI 2f07248211
Prep 0.20.0 (#77)
* fix issue where recent save cover images would sometimes not load

* paraphrase prompt tweaks

* action_to_narration regenerate compatibility fixes

* sim suite add answer question instruction

* more sim suite tweaks

* refactor agent details display in agent bar

* visual agent progress (a1111 support)

* visual gen prompt tweaks

* openai compat client pass max_tokens

* world state sequential reinforcement max tokens tightened

* improve item names

* Improve item names

* attempt to remove "changed from.." notes when altering an existing character sheet

* prompt improvements for single character portraits

* visual agent progress

* fix issue where character.update wouldn't update long-term memory

* remove experimental flag for now

* add better instructions for updating existing character sheet

* background processing for agents, visual and tts

* fix selected voice not saving between restarts for elevenlabs

* lessen timeout

* clean up agent status logic

* conditional agent configs

* comfyui support

* visualization queue

* refactor visual styles, comfyui progress

* regen images
auto cover image assign
websocket handler plugin abstraction
agent websocket handler

* automatic1111 fixes
agent status and ready checks

* tweaks to character portrait prompt

* system prompt for visualize

* textgenwebui use temp smoothing on yi models

* comment out api key for now

* fixes issues with openai compat client for retaining api key and auto fixing urls

* update_reinforcment tweaks

* agent status emit from one place

* emit agent status as asyncio task

* remove debug output

* tts add openai support

* openai img gen support

* fix issue with comfyui checkbox list not loading

* tts model selection for openai

* narrate_query include character sheet if character is referenced in query
improve visual character portrait generation prompt

* client implementation extra field support and runpod vllm client example

* relock

* fix issue where changing context length would cause next generation to error

* visual agent tweaks and auto gen character cover image in sim suite

* fix issue with readiness lock when there weren't any clients defined

* load scene readiness fixes

* linting

* docs

* notes for the runpod vllm example
2024-02-16 13:57:45 +02:00

130 lines
4 KiB
Python

"""
An attempt to write a client against the RunPod serverless vLLM worker.
This is close to functional, but since RunPod serverless GPU availability is currently terrible, I have
been unable to properly test it.
I am putting it here for now since I think it makes a decent example of how to write a client against a new service.
"""
import pydantic
import structlog
import runpod
import asyncio
import aiohttp
from talemate.client.base import ClientBase, ExtraField
from talemate.client.registry import register
from talemate.emit import emit
from talemate.config import Client as BaseClientConfig
# Module-level structlog logger, named after this client's dotted module path.
log = structlog.get_logger("talemate.client.runpod_vllm")
class Defaults(pydantic.BaseModel):
    """Default field values exposed through ``RunPodVLLMClient.Meta.defaults``."""

    # Default value for the ``max_token_length`` setting.
    max_token_length: int = 4096

    # Empty by default: no model name is pre-selected.
    model: str = ""

    # Empty by default: the Runpod ID has to be supplied by the user.
    runpod_id: str = ""
class ClientConfig(BaseClientConfig):
    """Client configuration extended with the ``runpod_id`` string field."""

    # Runpod ID for this client; an empty string means "not configured".
    runpod_id: str = ""
@register()
class RunPodVLLMClient(ClientBase):
    """Example client that sends prompts to a Runpod endpoint via the ``runpod`` SDK.

    The client is registered under the ``runpod_vllm`` client type. Besides
    the settings handled by ``ClientBase`` it carries one extra setting,
    ``runpod_id``, which is passed to ``runpod.AsyncioEndpoint`` on every
    generation request.
    """

    client_type = "runpod_vllm"
    conversation_retries = 5
    config_cls = ClientConfig

    class Meta(ClientBase.Meta):
        # Display strings and feature switches for this client type.
        title: str = "Runpod VLLM"
        name_prefix: str = "Runpod VLLM"
        enable_api_auth: bool = True
        manual_model: bool = True
        defaults: Defaults = Defaults()
        # Additional, client-specific input fields (here: the Runpod ID).
        extra_fields: dict[str, ExtraField] = {
            "runpod_id": ExtraField(
                name="runpod_id",
                type="text",
                label="Runpod ID",
                required=True,
                description="The Runpod ID to connect to.",
            )
        }

    def __init__(self, model=None, runpod_id=None, **kwargs):
        """Store the model name and Runpod ID, then initialise the base client.

        Args:
            model: Model name reported by ``get_model_name``.
            runpod_id: Runpod ID handed to ``runpod.AsyncioEndpoint``.
            **kwargs: Forwarded unchanged to ``ClientBase.__init__``.
        """
        # Assigned before the base initialiser runs, as in the original code.
        # NOTE(review): presumably the base __init__ may call set_client(),
        # which reads self.runpod_id -- confirm against ClientBase.
        self.model_name = model
        self.runpod_id = runpod_id
        super().__init__(**kwargs)

    @property
    def experimental(self):
        """Always ``False`` for this client."""
        return False

    def set_client(self, **kwargs):
        """Update ``self.runpod_id`` from ``kwargs``, keeping the current value if absent."""
        log.debug("set_client", kwargs=kwargs, runpod_id=self.runpod_id)
        self.runpod_id = kwargs.get("runpod_id", self.runpod_id)

    def tune_prompt_parameters(self, parameters: dict, kind: str):
        """Apply the base tuning, then strip unsupported keys from ``parameters`` in place.

        Only ``temperature``, ``top_p`` and ``max_tokens`` are kept.

        Args:
            parameters: Generation parameters; mutated in place.
            kind: Prompt kind, forwarded to the base implementation.
        """
        super().tune_prompt_parameters(parameters, kind)
        valid_keys = {"temperature", "top_p", "max_tokens"}
        # Iterate over a snapshot of the keys so deleting while looping is safe.
        for key in list(parameters):
            if key not in valid_keys:
                del parameters[key]

    async def get_model_name(self):
        """Return the model name this client was configured with."""
        return self.model_name

    async def generate(self, prompt: str, parameters: dict, kind: str):
        """Generate text from the given prompt.

        Submits the stripped prompt to the Runpod endpoint, polls the job
        until it reaches a terminal state and returns
        ``response["choices"][0]["tokens"][0]`` from the job output.

        Args:
            prompt: Prompt text; leading/trailing whitespace is stripped.
            parameters: Generation parameters. They are only logged here and
                are NOT forwarded to the endpoint (see the note below).
            kind: Prompt kind; unused by this method.

        Returns:
            The generated text, or ``""`` if anything goes wrong. Errors are
            logged and an error status is emitted instead of being raised.
        """
        prompt = prompt.strip()
        self.log.debug("generate", prompt=prompt[:128] + " ...", parameters=parameters)

        terminal_states = ("COMPLETED", "FAILED", "CANCELLED")

        try:
            async with aiohttp.ClientSession() as session:
                endpoint = runpod.AsyncioEndpoint(self.runpod_id, session)
                # NOTE(review): `parameters` are not sent with the request; the
                # original left that disabled. Also confirm against the runpod
                # SDK whether run() already wraps its argument in {"input": ...}.
                run_request = await endpoint.run(
                    {
                        "input": {
                            "prompt": prompt,
                        }
                    }
                )

                # Poll with one status request per iteration; each call is a
                # remote request, so the result is reused for check and log.
                status = await run_request.status()
                while status not in terminal_states:
                    log.debug("generate", status=status)
                    await asyncio.sleep(0.1)
                    status = await run_request.status()
                log.debug("generate", status=status)

                # A failed or cancelled job has no usable output; fail
                # explicitly so the log names the real cause instead of a
                # lookup error on the response.
                if status != "COMPLETED":
                    raise RuntimeError(f"runpod job ended with status {status}")

                response = await run_request.output()
                log.debug("generate", response=response)
                return response["choices"][0]["tokens"][0]
        except Exception as e:
            # Deliberate best-effort boundary: report the problem and hand an
            # empty string back to the caller rather than propagating.
            self.log.error("generate error", e=e)
            emit(
                "status", message="Error during generation (check logs)", status="error"
            )
            return ""

    def reconfigure(self, **kwargs):
        """Apply updated settings to this client.

        Args:
            **kwargs: New settings. A truthy ``model`` replaces the model
                name; ``runpod_id`` (if present) replaces the Runpod ID. All
                kwargs are then passed on to ``set_client``.
        """
        if kwargs.get("model"):
            self.model_name = kwargs["model"]
        if "runpod_id" in kwargs:
            # The Runpod ID belongs in `runpod_id`; it was previously written
            # to `api_auth`, overwriting the auth credential with the ID.
            self.runpod_id = kwargs["runpod_id"]
        # Routine event: logged at debug level, matching set_client().
        log.debug("reconfigure", kwargs=kwargs)
        self.set_client(**kwargs)