mirror of
https://github.com/vegu-ai/talemate.git
synced 2025-09-17 01:19:44 +00:00
0.25.2 (#108)
* fix typo * fix openai compat config save issue maybe * fix api_handles_prompt_template no longer saving changes after last fix * koboldcpp client * default to kobold ai api * linting * conversation cleanup tweak * 0.25.2 * allowed hosts to all on dev instance * ensure numbers on parameters when sending edited values * fix prompt parameter issues * remove debug message
This commit is contained in:
parent
60cb271e30
commit
a28cf2a029
12 changed files with 242 additions and 22 deletions
|
@ -4,7 +4,7 @@ build-backend = "poetry.masonry.api"
|
||||||
|
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "talemate"
|
name = "talemate"
|
||||||
version = "0.25.1"
|
version = "0.25.2"
|
||||||
description = "AI-backed roleplay and narrative tools"
|
description = "AI-backed roleplay and narrative tools"
|
||||||
authors = ["FinalWombat"]
|
authors = ["FinalWombat"]
|
||||||
license = "GNU Affero General Public License v3.0"
|
license = "GNU Affero General Public License v3.0"
|
||||||
|
|
|
@ -2,4 +2,4 @@ from .agents import Agent
|
||||||
from .client import TextGeneratorWebuiClient
|
from .client import TextGeneratorWebuiClient
|
||||||
from .tale_mate import *
|
from .tale_mate import *
|
||||||
|
|
||||||
VERSION = "0.25.1"
|
VERSION = "0.25.2"
|
||||||
|
|
|
@ -668,7 +668,9 @@ class ConversationAgent(Agent):
|
||||||
|
|
||||||
total_result = util.handle_endofline_special_delimiter(total_result)
|
total_result = util.handle_endofline_special_delimiter(total_result)
|
||||||
|
|
||||||
if total_result.startswith(":\n"):
|
log.info("conversation agent", total_result=total_result)
|
||||||
|
|
||||||
|
if total_result.startswith(":\n") or total_result.startswith(": "):
|
||||||
total_result = total_result[2:]
|
total_result = total_result[2:]
|
||||||
|
|
||||||
# movie script format
|
# movie script format
|
||||||
|
|
|
@ -5,9 +5,10 @@ from talemate.client.anthropic import AnthropicClient
|
||||||
from talemate.client.cohere import CohereClient
|
from talemate.client.cohere import CohereClient
|
||||||
from talemate.client.google import GoogleClient
|
from talemate.client.google import GoogleClient
|
||||||
from talemate.client.groq import GroqClient
|
from talemate.client.groq import GroqClient
|
||||||
|
from talemate.client.koboldccp import KoboldCppClient
|
||||||
from talemate.client.lmstudio import LMStudioClient
|
from talemate.client.lmstudio import LMStudioClient
|
||||||
from talemate.client.mistral import MistralAIClient
|
from talemate.client.mistral import MistralAIClient
|
||||||
from talemate.client.openai import OpenAIClient
|
from talemate.client.openai import OpenAIClient
|
||||||
from talemate.client.openai_compat import OpenAICompatibleClient
|
from talemate.client.openai_compat import OpenAICompatibleClient
|
||||||
from talemate.client.registry import CLIENT_CLASSES, get_client_class, register
|
from talemate.client.registry import CLIENT_CLASSES, get_client_class, register
|
||||||
from talemate.client.textgenwebui import TextGeneratorWebuiClient
|
from talemate.client.textgenwebui import TextGeneratorWebuiClient
|
|
@ -1,16 +1,201 @@
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import random
|
import random
|
||||||
from abc import ABC, abstractmethod
|
import re
|
||||||
from typing import Callable, Union
|
|
||||||
|
|
||||||
import requests
|
# import urljoin
|
||||||
|
from urllib.parse import urljoin
|
||||||
|
import httpx
|
||||||
|
import structlog
|
||||||
|
|
||||||
import talemate.client.system_prompts as system_prompts
|
from talemate.client.base import STOPPING_STRINGS, ClientBase, Defaults, ExtraField
|
||||||
import talemate.util as util
|
|
||||||
from talemate.client.registry import register
|
from talemate.client.registry import register
|
||||||
from talemate.client.textgenwebui import RESTTaleMateClient
|
|
||||||
from talemate.emit import Emission, emit
|
|
||||||
|
|
||||||
# NOT IMPLEMENTED AT THIS POINT
|
log = structlog.get_logger("talemate.client.koboldcpp")
|
||||||
|
|
||||||
|
|
||||||
|
class KoboldCppClientDefaults(Defaults):
|
||||||
|
api_key: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@register()
|
||||||
|
class KoboldCppClient(ClientBase):
|
||||||
|
auto_determine_prompt_template: bool = True
|
||||||
|
client_type = "koboldcpp"
|
||||||
|
|
||||||
|
class Meta(ClientBase.Meta):
|
||||||
|
name_prefix: str = "KoboldCpp"
|
||||||
|
title: str = "KoboldCpp"
|
||||||
|
enable_api_auth: bool = True
|
||||||
|
defaults: KoboldCppClientDefaults = KoboldCppClientDefaults()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def request_headers(self):
|
||||||
|
headers = {}
|
||||||
|
headers["Content-Type"] = "application/json"
|
||||||
|
if self.api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||||
|
return headers
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_openai(self) -> bool:
|
||||||
|
"""
|
||||||
|
kcpp has two apis
|
||||||
|
|
||||||
|
open-ai implementation at /v1
|
||||||
|
their own implenation at /api/v1
|
||||||
|
"""
|
||||||
|
return "/api/v1" not in self.api_url
|
||||||
|
|
||||||
|
@property
|
||||||
|
def api_url_for_model(self) -> str:
|
||||||
|
if self.is_openai:
|
||||||
|
# join /model to url
|
||||||
|
return urljoin(self.api_url, "models")
|
||||||
|
else:
|
||||||
|
# join /models to url
|
||||||
|
return urljoin(self.api_url, "model")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def api_url_for_generation(self) -> str:
|
||||||
|
if self.is_openai:
|
||||||
|
# join /v1/completions
|
||||||
|
return urljoin(self.api_url, "completions")
|
||||||
|
else:
|
||||||
|
# join /api/v1/generate
|
||||||
|
return urljoin(self.api_url, "generate")
|
||||||
|
|
||||||
|
def api_endpoint_specified(self, url: str) -> bool:
|
||||||
|
return "/v1" in self.api_url
|
||||||
|
|
||||||
|
def ensure_api_endpoint_specified(self):
|
||||||
|
if not self.api_endpoint_specified(self.api_url):
|
||||||
|
# url doesn't specify the api endpoint
|
||||||
|
# use the koboldcpp openai api
|
||||||
|
self.api_url = urljoin(self.api_url.rstrip("/") + "/", "/api/v1/")
|
||||||
|
if not self.api_url.endswith("/"):
|
||||||
|
self.api_url += "/"
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self.api_key = kwargs.pop("api_key", "")
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.ensure_api_endpoint_specified()
|
||||||
|
|
||||||
|
def tune_prompt_parameters(self, parameters: dict, kind: str):
|
||||||
|
super().tune_prompt_parameters(parameters, kind)
|
||||||
|
if not self.is_openai:
|
||||||
|
# adjustments for united api
|
||||||
|
parameters["max_length"] = parameters.pop("max_tokens")
|
||||||
|
parameters["max_context_length"] = self.max_token_length
|
||||||
|
if "repetition_penalty_range" in parameters:
|
||||||
|
parameters["rep_pen_range"] = parameters.pop("repetition_penalty_range")
|
||||||
|
if "repetition_penalty" in parameters:
|
||||||
|
parameters["rep_pen"] = parameters.pop("repetition_penalty")
|
||||||
|
if parameters.get("stop_sequence"):
|
||||||
|
parameters["stop_sequence"] = parameters.pop("stopping_strings")
|
||||||
|
|
||||||
|
if parameters.get("extra_stopping_strings"):
|
||||||
|
if "stop_sequence" in parameters:
|
||||||
|
parameters["stop_sequence"] += parameters.pop("extra_stopping_strings")
|
||||||
|
else:
|
||||||
|
parameters["stop_sequence"] = parameters.pop("extra_stopping_strings")
|
||||||
|
|
||||||
|
|
||||||
|
allowed_params = [
|
||||||
|
"max_length",
|
||||||
|
"max_context_length",
|
||||||
|
"rep_pen",
|
||||||
|
"rep_pen_range",
|
||||||
|
"top_p",
|
||||||
|
"top_k",
|
||||||
|
"temperature",
|
||||||
|
"stop_sequence",
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# adjustments for openai api
|
||||||
|
if "repetition_penalty" in parameters:
|
||||||
|
parameters["presence_penalty"] = parameters.pop(
|
||||||
|
"repetition_penalty"
|
||||||
|
)
|
||||||
|
|
||||||
|
allowed_params = ["max_tokens", "presence_penalty", "top_p", "temperature"]
|
||||||
|
|
||||||
|
# drop unsupported params
|
||||||
|
for param in list(parameters.keys()):
|
||||||
|
if param not in allowed_params:
|
||||||
|
del parameters[param]
|
||||||
|
|
||||||
|
def set_client(self, **kwargs):
|
||||||
|
self.api_key = kwargs.get("api_key", self.api_key)
|
||||||
|
self.ensure_api_endpoint_specified()
|
||||||
|
|
||||||
|
async def get_model_name(self):
|
||||||
|
self.ensure_api_endpoint_specified()
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(
|
||||||
|
self.api_url_for_model,
|
||||||
|
timeout=2,
|
||||||
|
headers=self.request_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 404:
|
||||||
|
raise KeyError(f"Could not find model info at: {self.api_url_for_model}")
|
||||||
|
|
||||||
|
response_data = response.json()
|
||||||
|
if self.is_openai:
|
||||||
|
# {"object": "list", "data": [{"id": "koboldcpp/dolphin-2.8-mistral-7b", "object": "model", "created": 1, "owned_by": "koboldcpp", "permission": [], "root": "koboldcpp"}]}
|
||||||
|
model_name = response_data.get("data")[0].get("id")
|
||||||
|
else:
|
||||||
|
# {"result": "koboldcpp/dolphin-2.8-mistral-7b"}
|
||||||
|
model_name = response_data.get("result")
|
||||||
|
|
||||||
|
# split by "/" and take last
|
||||||
|
if model_name:
|
||||||
|
model_name = model_name.split("/")[-1]
|
||||||
|
|
||||||
|
return model_name
|
||||||
|
|
||||||
|
async def generate(self, prompt: str, parameters: dict, kind: str):
|
||||||
|
"""
|
||||||
|
Generates text from the given prompt and parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
parameters["prompt"] = prompt.strip(" ")
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(
|
||||||
|
self.api_url_for_generation,
|
||||||
|
json=parameters,
|
||||||
|
timeout=None,
|
||||||
|
headers=self.request_headers,
|
||||||
|
)
|
||||||
|
response_data = response.json()
|
||||||
|
|
||||||
|
try:
|
||||||
|
if self.is_openai:
|
||||||
|
return response_data["choices"][0]["text"]
|
||||||
|
else:
|
||||||
|
return response_data["results"][0]["text"]
|
||||||
|
except (TypeError, KeyError) as exc:
|
||||||
|
log.error("Failed to generate text", exc=exc, response_data=response_data, response_status=response.status_code)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def jiggle_randomness(self, prompt_config: dict, offset: float = 0.3) -> dict:
|
||||||
|
"""
|
||||||
|
adjusts temperature and repetition_penalty
|
||||||
|
by random values using the base value as a center
|
||||||
|
"""
|
||||||
|
|
||||||
|
temp = prompt_config["temperature"]
|
||||||
|
rep_pen = prompt_config["rep_pen"]
|
||||||
|
|
||||||
|
min_offset = offset * 0.3
|
||||||
|
|
||||||
|
prompt_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
|
||||||
|
prompt_config["rep_pen"] = random.uniform(
|
||||||
|
rep_pen + min_offset * 0.3, rep_pen + offset * 0.3
|
||||||
|
)
|
||||||
|
|
||||||
|
def reconfigure(self, **kwargs):
|
||||||
|
if "api_key" in kwargs:
|
||||||
|
self.api_key = kwargs.pop("api_key")
|
||||||
|
|
||||||
|
super().reconfigure(**kwargs)
|
||||||
|
|
|
@ -11,6 +11,20 @@ class TestPromptPayload(pydantic.BaseModel):
|
||||||
kind: str
|
kind: str
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_number(v):
|
||||||
|
"""
|
||||||
|
if v is a str but digit turn into into or float
|
||||||
|
"""
|
||||||
|
|
||||||
|
if isinstance(v, str):
|
||||||
|
if v.isdigit():
|
||||||
|
return int(v)
|
||||||
|
try:
|
||||||
|
return float(v)
|
||||||
|
except ValueError:
|
||||||
|
return v
|
||||||
|
return v
|
||||||
|
|
||||||
class DevToolsPlugin:
|
class DevToolsPlugin:
|
||||||
router = "devtools"
|
router = "devtools"
|
||||||
|
|
||||||
|
@ -34,7 +48,7 @@ class DevToolsPlugin:
|
||||||
log.info(
|
log.info(
|
||||||
"Testing prompt",
|
"Testing prompt",
|
||||||
payload={
|
payload={
|
||||||
k: v for k, v in payload.generation_parameters.items() if k != "prompt"
|
k: ensure_number(v) for k, v in payload.generation_parameters.items() if k != "prompt"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
4
talemate_frontend/package-lock.json
generated
4
talemate_frontend/package-lock.json
generated
|
@ -1,12 +1,12 @@
|
||||||
{
|
{
|
||||||
"name": "talemate_frontend",
|
"name": "talemate_frontend",
|
||||||
"version": "0.25.1",
|
"version": "0.25.2",
|
||||||
"lockfileVersion": 2,
|
"lockfileVersion": 2,
|
||||||
"requires": true,
|
"requires": true,
|
||||||
"packages": {
|
"packages": {
|
||||||
"": {
|
"": {
|
||||||
"name": "talemate_frontend",
|
"name": "talemate_frontend",
|
||||||
"version": "0.25.1",
|
"version": "0.25.2",
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@codemirror/lang-markdown": "^6.2.5",
|
"@codemirror/lang-markdown": "^6.2.5",
|
||||||
"@codemirror/theme-one-dark": "^6.1.2",
|
"@codemirror/theme-one-dark": "^6.1.2",
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
{
|
{
|
||||||
"name": "talemate_frontend",
|
"name": "talemate_frontend",
|
||||||
"version": "0.25.1",
|
"version": "0.25.2",
|
||||||
"private": true,
|
"private": true,
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"serve": "vue-cli-service serve",
|
"serve": "vue-cli-service serve",
|
||||||
|
|
|
@ -244,6 +244,13 @@ export default {
|
||||||
client.api_key = data.api_key;
|
client.api_key = data.api_key;
|
||||||
client.double_coercion = data.data.double_coercion;
|
client.double_coercion = data.data.double_coercion;
|
||||||
client.data = data.data;
|
client.data = data.data;
|
||||||
|
for (let key in client.data.meta.extra_fields) {
|
||||||
|
if (client.data[key] === null || client.data[key] === undefined) {
|
||||||
|
client.data[key] = client.data.meta.defaults[key];
|
||||||
|
}
|
||||||
|
client[key] = client.data[key];
|
||||||
|
}
|
||||||
|
|
||||||
} else if(!client) {
|
} else if(!client) {
|
||||||
console.log("Adding new client", data);
|
console.log("Adding new client", data);
|
||||||
|
|
||||||
|
@ -259,6 +266,16 @@ export default {
|
||||||
double_coercion: data.data.double_coercion,
|
double_coercion: data.data.double_coercion,
|
||||||
data: data.data,
|
data: data.data,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// apply extra field defaults
|
||||||
|
let client = this.state.clients[this.state.clients.length - 1];
|
||||||
|
for (let key in client.data.meta.extra_fields) {
|
||||||
|
if (client.data[key] === null || client.data[key] === undefined) {
|
||||||
|
client.data[key] = client.data.meta.defaults[key];
|
||||||
|
}
|
||||||
|
client[key] = client.data[key];
|
||||||
|
}
|
||||||
|
|
||||||
// sort the clients by name
|
// sort the clients by name
|
||||||
this.state.clients.sort((a, b) => (a.name > b.name) ? 1 : -1);
|
this.state.clients.sort((a, b) => (a.name > b.name) ? 1 : -1);
|
||||||
}
|
}
|
||||||
|
|
|
@ -56,9 +56,9 @@
|
||||||
</v-row>
|
</v-row>
|
||||||
<v-row v-for="field in clientMeta().extra_fields" :key="field.name">
|
<v-row v-for="field in clientMeta().extra_fields" :key="field.name">
|
||||||
<v-col cols="12">
|
<v-col cols="12">
|
||||||
<v-text-field v-model="client.data[field.name]" v-if="field.type === 'text'" :label="field.label"
|
<v-text-field v-model="client[field.name]" v-if="field.type === 'text'" :label="field.label"
|
||||||
:rules="[rules.required]" :hint="field.description"></v-text-field>
|
:rules="[rules.required]" :hint="field.description"></v-text-field>
|
||||||
<v-checkbox v-else-if="field.type === 'bool'" v-model="client.data[field.name]"
|
<v-checkbox v-else-if="field.type === 'bool'" v-model="client[field.name]"
|
||||||
:label="field.label" :hint="field.description" density="compact"></v-checkbox>
|
:label="field.label" :hint="field.description" density="compact"></v-checkbox>
|
||||||
</v-col>
|
</v-col>
|
||||||
</v-row>
|
</v-row>
|
||||||
|
|
|
@ -248,7 +248,7 @@ export default {
|
||||||
messageHandlers: [],
|
messageHandlers: [],
|
||||||
scene: {},
|
scene: {},
|
||||||
appConfig: {},
|
appConfig: {},
|
||||||
autcompleting: false,
|
autocompleting: false,
|
||||||
autocompletePartialInput: "",
|
autocompletePartialInput: "",
|
||||||
autocompleteCallback: null,
|
autocompleteCallback: null,
|
||||||
autocompleteFocusElement: null,
|
autocompleteFocusElement: null,
|
||||||
|
|
|
@ -9,6 +9,7 @@ module.exports = defineConfig({
|
||||||
},
|
},
|
||||||
|
|
||||||
devServer: {
|
devServer: {
|
||||||
|
allowedHosts: "all",
|
||||||
client: {
|
client: {
|
||||||
overlay: {
|
overlay: {
|
||||||
warnings: false,
|
warnings: false,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue