mirror of
https://github.com/vegu-ai/talemate.git
synced 2025-09-17 01:19:44 +00:00
0.25.2 (#108)
* fix typo * fix openai compat config save issue maybe * fix api_handles_prompt_template no longer saving changes after last fix * koboldcpp client * default to kobold ai api * linting * conversation cleanup tweak * 0.25.2 * allowed hosts to all on dev instance * ensure numbers on parameters when sending edited values * fix prompt parameter issues * remove debug message
This commit is contained in:
parent
60cb271e30
commit
a28cf2a029
12 changed files with 242 additions and 22 deletions
|
@ -4,7 +4,7 @@ build-backend = "poetry.masonry.api"
|
|||
|
||||
[tool.poetry]
|
||||
name = "talemate"
|
||||
version = "0.25.1"
|
||||
version = "0.25.2"
|
||||
description = "AI-backed roleplay and narrative tools"
|
||||
authors = ["FinalWombat"]
|
||||
license = "GNU Affero General Public License v3.0"
|
||||
|
|
|
@ -2,4 +2,4 @@ from .agents import Agent
|
|||
from .client import TextGeneratorWebuiClient
|
||||
from .tale_mate import *
|
||||
|
||||
VERSION = "0.25.1"
|
||||
VERSION = "0.25.2"
|
||||
|
|
|
@ -668,7 +668,9 @@ class ConversationAgent(Agent):
|
|||
|
||||
total_result = util.handle_endofline_special_delimiter(total_result)
|
||||
|
||||
if total_result.startswith(":\n"):
|
||||
log.info("conversation agent", total_result=total_result)
|
||||
|
||||
if total_result.startswith(":\n") or total_result.startswith(": "):
|
||||
total_result = total_result[2:]
|
||||
|
||||
# movie script format
|
||||
|
|
|
@ -5,9 +5,10 @@ from talemate.client.anthropic import AnthropicClient
|
|||
from talemate.client.cohere import CohereClient
|
||||
from talemate.client.google import GoogleClient
|
||||
from talemate.client.groq import GroqClient
|
||||
from talemate.client.koboldccp import KoboldCppClient
|
||||
from talemate.client.lmstudio import LMStudioClient
|
||||
from talemate.client.mistral import MistralAIClient
|
||||
from talemate.client.openai import OpenAIClient
|
||||
from talemate.client.openai_compat import OpenAICompatibleClient
|
||||
from talemate.client.registry import CLIENT_CLASSES, get_client_class, register
|
||||
from talemate.client.textgenwebui import TextGeneratorWebuiClient
|
||||
from talemate.client.textgenwebui import TextGeneratorWebuiClient
|
|
@ -1,16 +1,201 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Union
|
||||
import re
|
||||
|
||||
import requests
|
||||
# import urljoin
|
||||
from urllib.parse import urljoin
|
||||
import httpx
|
||||
import structlog
|
||||
|
||||
import talemate.client.system_prompts as system_prompts
|
||||
import talemate.util as util
|
||||
from talemate.client.base import STOPPING_STRINGS, ClientBase, Defaults, ExtraField
|
||||
from talemate.client.registry import register
|
||||
from talemate.client.textgenwebui import RESTTaleMateClient
|
||||
from talemate.emit import Emission, emit
|
||||
|
||||
# NOT IMPLEMENTED AT THIS POINT
|
||||
log = structlog.get_logger("talemate.client.koboldcpp")
|
||||
|
||||
|
||||
# Default field values for the KoboldCpp client: the api key is empty unless configured.
class KoboldCppClientDefaults(Defaults):
|
||||
api_key: str = ""
|
||||
|
||||
|
||||
@register()
class KoboldCppClient(ClientBase):
    """
    Client for a KoboldCpp server.

    KoboldCpp exposes two HTTP APIs and this client can talk to either one,
    selected by the configured ``api_url``:

    - the OpenAI-compatible API at ``/v1``
    - KoboldCpp's own ("united") API at ``/api/v1``

    When the url names neither, the united API at ``/api/v1/`` is used.
    """

    auto_determine_prompt_template: bool = True
    client_type = "koboldcpp"

    class Meta(ClientBase.Meta):
        name_prefix: str = "KoboldCpp"
        title: str = "KoboldCpp"
        enable_api_auth: bool = True
        defaults: KoboldCppClientDefaults = KoboldCppClientDefaults()

    @property
    def request_headers(self):
        """
        Headers sent with every request.

        An ``Authorization: Bearer`` header is only added when an api key
        is configured.
        """
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    @property
    def is_openai(self) -> bool:
        """
        kcpp has two apis

        open-ai implementation at /v1
        their own implementation at /api/v1

        Returns True when ``api_url`` does not point at ``/api/v1``.
        """
        return "/api/v1" not in self.api_url

    @property
    def api_url_for_model(self) -> str:
        """
        Url used to look up the name of the loaded model.

        Relies on ``api_url`` ending in ``/`` (see
        ``ensure_api_endpoint_specified``) so that ``urljoin`` appends to
        the path instead of replacing its last segment.
        """
        if self.is_openai:
            # join /models to url
            return urljoin(self.api_url, "models")
        # join /model to url
        return urljoin(self.api_url, "model")

    @property
    def api_url_for_generation(self) -> str:
        """Url used for text generation requests."""
        if self.is_openai:
            # join /v1/completions
            return urljoin(self.api_url, "completions")
        # join /api/v1/generate
        return urljoin(self.api_url, "generate")

    def api_endpoint_specified(self, url: str) -> bool:
        """
        Returns whether ``url`` already names one of the two api endpoints.

        ``/api/v1`` contains ``/v1``, so a single check covers both apis.
        """
        return "/v1" in url

    def ensure_api_endpoint_specified(self):
        """
        Normalizes ``self.api_url`` so that it names an api endpoint and
        ends with a trailing slash.
        """
        if not self.api_endpoint_specified(self.api_url):
            # url doesn't specify the api endpoint
            # default to koboldcpp's own (united) api at /api/v1/
            self.api_url = urljoin(self.api_url.rstrip("/") + "/", "/api/v1/")
        if not self.api_url.endswith("/"):
            self.api_url += "/"

    def __init__(self, **kwargs):
        # api_key is consumed here and not forwarded to the base class
        self.api_key = kwargs.pop("api_key", "")
        super().__init__(**kwargs)
        self.ensure_api_endpoint_specified()

    def tune_prompt_parameters(self, parameters: dict, kind: str):
        """
        Renames generation parameters to the names expected by the api in
        use and removes every parameter that api is not sent.

        ``parameters`` is modified in place.
        """
        super().tune_prompt_parameters(parameters, kind)

        if not self.is_openai:
            # adjustments for united api
            if "max_tokens" in parameters:
                parameters["max_length"] = parameters.pop("max_tokens")
            parameters["max_context_length"] = self.max_token_length

            if "repetition_penalty_range" in parameters:
                parameters["rep_pen_range"] = parameters.pop("repetition_penalty_range")
            if "repetition_penalty" in parameters:
                parameters["rep_pen"] = parameters.pop("repetition_penalty")

            # the united api takes stopping strings as `stop_sequence`
            if parameters.get("stopping_strings"):
                parameters["stop_sequence"] = parameters.pop("stopping_strings")

            extra_stopping_strings = parameters.pop("extra_stopping_strings", None)
            if extra_stopping_strings:
                # build a new list rather than extending in place, so a list
                # owned by the caller is never mutated
                # (assumes both values are lists of strings)
                parameters["stop_sequence"] = [
                    *(parameters.get("stop_sequence") or []),
                    *extra_stopping_strings,
                ]

            allowed_params = (
                "max_length",
                "max_context_length",
                "rep_pen",
                "rep_pen_range",
                "top_p",
                "top_k",
                "temperature",
                "stop_sequence",
            )
        else:
            # adjustments for openai api
            if "repetition_penalty" in parameters:
                parameters["presence_penalty"] = parameters.pop("repetition_penalty")

            allowed_params = ("max_tokens", "presence_penalty", "top_p", "temperature")

        # drop unsupported params
        for param in list(parameters):
            if param not in allowed_params:
                del parameters[param]

    def set_client(self, **kwargs):
        """Updates the api key (when given) and re-normalizes the api url."""
        self.api_key = kwargs.get("api_key", self.api_key)
        self.ensure_api_endpoint_specified()

    async def get_model_name(self):
        """
        Asks the server for the name of the loaded model.

        Returns the last ``/``-separated segment of the reported name, or
        the reported value unchanged when it is empty or missing.

        Raises KeyError when the model endpoint responds with a 404.
        """
        self.ensure_api_endpoint_specified()

        async with httpx.AsyncClient() as client:
            response = await client.get(
                self.api_url_for_model,
                timeout=2,
                headers=self.request_headers,
            )

        if response.status_code == 404:
            raise KeyError(f"Could not find model info at: {self.api_url_for_model}")

        response_data = response.json()

        if self.is_openai:
            # {"object": "list", "data": [{"id": "koboldcpp/dolphin-2.8-mistral-7b", "object": "model", "created": 1, "owned_by": "koboldcpp", "permission": [], "root": "koboldcpp"}]}
            models = response_data.get("data") or []
            # an empty or missing list means no model name is known
            model_name = models[0].get("id") if models else None
        else:
            # {"result": "koboldcpp/dolphin-2.8-mistral-7b"}
            model_name = response_data.get("result")

        # split by "/" and take last
        if model_name:
            model_name = model_name.split("/")[-1]

        return model_name

    async def generate(self, prompt: str, parameters: dict, kind: str):
        """
        Generates text from the given prompt and parameters.

        Returns the generated text, or an empty string (after logging an
        error) when the response does not have the expected shape.
        """

        parameters["prompt"] = prompt.strip(" ")

        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.api_url_for_generation,
                json=parameters,
                timeout=None,
                headers=self.request_headers,
            )
            response_data = response.json()

        # the two apis wrap the generated text under different keys
        result_key = "choices" if self.is_openai else "results"
        try:
            return response_data[result_key][0]["text"]
        except (TypeError, KeyError, IndexError) as exc:
            log.error(
                "Failed to generate text",
                exc=exc,
                response_data=response_data,
                response_status=response.status_code,
            )
            return ""

    def jiggle_randomness(self, prompt_config: dict, offset: float = 0.3) -> dict:
        """
        adjusts temperature and repetition_penalty
        by random values using the base value as a center

        ``prompt_config`` is modified in place and also returned. The
        penalty is adjusted under whichever of ``rep_pen``,
        ``presence_penalty`` or ``repetition_penalty`` is present, since
        ``tune_prompt_parameters`` renames it depending on the api in use.
        """

        temp = prompt_config["temperature"]
        min_offset = offset * 0.3

        prompt_config["temperature"] = random.uniform(temp + min_offset, temp + offset)

        for rep_pen_key in ("rep_pen", "presence_penalty", "repetition_penalty"):
            if rep_pen_key in prompt_config:
                rep_pen = prompt_config[rep_pen_key]
                prompt_config[rep_pen_key] = random.uniform(
                    rep_pen + min_offset * 0.3, rep_pen + offset * 0.3
                )
                break

        return prompt_config

    def reconfigure(self, **kwargs):
        """Applies a new api key (when given) and forwards the rest to the base class."""
        if "api_key" in kwargs:
            self.api_key = kwargs.pop("api_key")

        super().reconfigure(**kwargs)
|
||||
|
|
|
@ -11,6 +11,20 @@ class TestPromptPayload(pydantic.BaseModel):
|
|||
kind: str
|
||||
|
||||
|
||||
def ensure_number(v):
    """
    if v is a str that holds a number, turn it into an int or float

    Integers (including negative ones) become ``int``, other finite
    numbers become ``float``. Anything else - non-numeric strings,
    non-finite values such as "nan" or "inf", and non-str values - is
    returned unchanged.
    """
    import math

    if not isinstance(v, str):
        return v

    # int() is the authority on what parses as an integer; str.isdigit()
    # accepts characters such as "²" that int() rejects
    try:
        return int(v)
    except ValueError:
        pass

    try:
        number = float(v)
    except ValueError:
        return v

    # nan / inf cannot be represented in strict JSON, keep the original text
    if not math.isfinite(number):
        return v

    return number
|
||||
|
||||
class DevToolsPlugin:
|
||||
router = "devtools"
|
||||
|
||||
|
@ -34,7 +48,7 @@ class DevToolsPlugin:
|
|||
log.info(
|
||||
"Testing prompt",
|
||||
payload={
|
||||
k: v for k, v in payload.generation_parameters.items() if k != "prompt"
|
||||
k: ensure_number(v) for k, v in payload.generation_parameters.items() if k != "prompt"
|
||||
},
|
||||
)
|
||||
|
||||
|
|
4
talemate_frontend/package-lock.json
generated
4
talemate_frontend/package-lock.json
generated
|
@ -1,12 +1,12 @@
|
|||
{
|
||||
"name": "talemate_frontend",
|
||||
"version": "0.25.1",
|
||||
"version": "0.25.2",
|
||||
"lockfileVersion": 2,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "talemate_frontend",
|
||||
"version": "0.25.1",
|
||||
"version": "0.25.2",
|
||||
"dependencies": {
|
||||
"@codemirror/lang-markdown": "^6.2.5",
|
||||
"@codemirror/theme-one-dark": "^6.1.2",
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"name": "talemate_frontend",
|
||||
"version": "0.25.1",
|
||||
"version": "0.25.2",
|
||||
"private": true,
|
||||
"scripts": {
|
||||
"serve": "vue-cli-service serve",
|
||||
|
|
|
@ -244,6 +244,13 @@ export default {
|
|||
client.api_key = data.api_key;
|
||||
client.double_coercion = data.data.double_coercion;
|
||||
client.data = data.data;
|
||||
for (let key in client.data.meta.extra_fields) {
|
||||
if (client.data[key] === null || client.data[key] === undefined) {
|
||||
client.data[key] = client.data.meta.defaults[key];
|
||||
}
|
||||
client[key] = client.data[key];
|
||||
}
|
||||
|
||||
} else if(!client) {
|
||||
console.log("Adding new client", data);
|
||||
|
||||
|
@ -259,6 +266,16 @@ export default {
|
|||
double_coercion: data.data.double_coercion,
|
||||
data: data.data,
|
||||
});
|
||||
|
||||
// apply extra field defaults
|
||||
let client = this.state.clients[this.state.clients.length - 1];
|
||||
for (let key in client.data.meta.extra_fields) {
|
||||
if (client.data[key] === null || client.data[key] === undefined) {
|
||||
client.data[key] = client.data.meta.defaults[key];
|
||||
}
|
||||
client[key] = client.data[key];
|
||||
}
|
||||
|
||||
// sort the clients by name
|
||||
this.state.clients.sort((a, b) => (a.name > b.name) ? 1 : -1);
|
||||
}
|
||||
|
|
|
@ -56,9 +56,9 @@
|
|||
</v-row>
|
||||
<v-row v-for="field in clientMeta().extra_fields" :key="field.name">
|
||||
<v-col cols="12">
|
||||
<v-text-field v-model="client.data[field.name]" v-if="field.type === 'text'" :label="field.label"
|
||||
<v-text-field v-model="client[field.name]" v-if="field.type === 'text'" :label="field.label"
|
||||
:rules="[rules.required]" :hint="field.description"></v-text-field>
|
||||
<v-checkbox v-else-if="field.type === 'bool'" v-model="client.data[field.name]"
|
||||
<v-checkbox v-else-if="field.type === 'bool'" v-model="client[field.name]"
|
||||
:label="field.label" :hint="field.description" density="compact"></v-checkbox>
|
||||
</v-col>
|
||||
</v-row>
|
||||
|
|
|
@ -248,7 +248,7 @@ export default {
|
|||
messageHandlers: [],
|
||||
scene: {},
|
||||
appConfig: {},
|
||||
autcompleting: false,
|
||||
autocompleting: false,
|
||||
autocompletePartialInput: "",
|
||||
autocompleteCallback: null,
|
||||
autocompleteFocusElement: null,
|
||||
|
|
|
@ -9,6 +9,7 @@ module.exports = defineConfig({
|
|||
},
|
||||
|
||||
devServer: {
|
||||
allowedHosts: "all",
|
||||
client: {
|
||||
overlay: {
|
||||
warnings: false,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue