* fix typo

* fix openai compat config save issue maybe

* fix api_handles_prompt_template no longer saving changes after last fix

* koboldcpp client

* default to kobold ai api

* linting

* conversation cleanup tweak

* 0.25.2

* allowed hosts to all on dev instance

* ensure numbers on parameters when sending edited values

* fix prompt parameter issues

* remove debug message
This commit is contained in:
veguAI 2024-05-10 21:29:29 +03:00 committed by GitHub
parent 60cb271e30
commit a28cf2a029
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 242 additions and 22 deletions

View file

@ -4,7 +4,7 @@ build-backend = "poetry.masonry.api"
[tool.poetry]
name = "talemate"
version = "0.25.1"
version = "0.25.2"
description = "AI-backed roleplay and narrative tools"
authors = ["FinalWombat"]
license = "GNU Affero General Public License v3.0"

View file

@ -2,4 +2,4 @@ from .agents import Agent
from .client import TextGeneratorWebuiClient
from .tale_mate import *
VERSION = "0.25.1"
VERSION = "0.25.2"

View file

@ -668,7 +668,9 @@ class ConversationAgent(Agent):
total_result = util.handle_endofline_special_delimiter(total_result)
if total_result.startswith(":\n"):
log.info("conversation agent", total_result=total_result)
if total_result.startswith(":\n") or total_result.startswith(": "):
total_result = total_result[2:]
# movie script format

View file

@ -5,9 +5,10 @@ from talemate.client.anthropic import AnthropicClient
from talemate.client.cohere import CohereClient
from talemate.client.google import GoogleClient
from talemate.client.groq import GroqClient
from talemate.client.koboldccp import KoboldCppClient
from talemate.client.lmstudio import LMStudioClient
from talemate.client.mistral import MistralAIClient
from talemate.client.openai import OpenAIClient
from talemate.client.openai_compat import OpenAICompatibleClient
from talemate.client.registry import CLIENT_CLASSES, get_client_class, register
from talemate.client.textgenwebui import TextGeneratorWebuiClient
from talemate.client.textgenwebui import TextGeneratorWebuiClient

View file

@ -1,16 +1,201 @@
import asyncio
import json
import logging
import random
from abc import ABC, abstractmethod
from typing import Callable, Union
import re
import requests
# import urljoin
from urllib.parse import urljoin
import httpx
import structlog
import talemate.client.system_prompts as system_prompts
import talemate.util as util
from talemate.client.base import STOPPING_STRINGS, ClientBase, Defaults, ExtraField
from talemate.client.registry import register
from talemate.client.textgenwebui import RESTTaleMateClient
from talemate.emit import Emission, emit
# NOT IMPLEMENTED AT THIS POINT
log = structlog.get_logger("talemate.client.koboldcpp")
class KoboldCppClientDefaults(Defaults):
    """Client defaults for KoboldCpp; extends the base defaults with an optional API key."""

    # bearer token sent as `Authorization: Bearer <key>` when non-empty
    api_key: str = ""
@register()
class KoboldCppClient(ClientBase):
    """
    Client for a KoboldCpp server.

    KoboldCpp exposes two HTTP APIs and this client supports both:

    - an OpenAI-compatible API mounted at ``/v1``
    - KoboldCpp's own ("united") API mounted at ``/api/v1``

    Which one is used is derived from the configured ``api_url``
    (see :attr:`is_openai`).
    """

    auto_determine_prompt_template: bool = True
    client_type = "koboldcpp"

    class Meta(ClientBase.Meta):
        name_prefix: str = "KoboldCpp"
        title: str = "KoboldCpp"
        enable_api_auth: bool = True
        defaults: KoboldCppClientDefaults = KoboldCppClientDefaults()

    def __init__(self, **kwargs):
        self.api_key = kwargs.pop("api_key", "")
        super().__init__(**kwargs)
        self.ensure_api_endpoint_specified()

    @property
    def request_headers(self) -> dict:
        """Headers for every request; adds bearer auth when an api key is set."""
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    @property
    def is_openai(self) -> bool:
        """
        True when the configured url points at the OpenAI-compatible
        implementation at ``/v1`` rather than the united implementation
        at ``/api/v1``.
        """
        return "/api/v1" not in self.api_url

    @property
    def api_url_for_model(self) -> str:
        """URL used to query the currently loaded model."""
        if self.is_openai:
            # openai api: join /models to url
            return urljoin(self.api_url, "models")
        else:
            # united api: join /model to url
            return urljoin(self.api_url, "model")

    @property
    def api_url_for_generation(self) -> str:
        """URL used to request a text generation."""
        if self.is_openai:
            # /v1/completions
            return urljoin(self.api_url, "completions")
        else:
            # /api/v1/generate
            return urljoin(self.api_url, "generate")

    def api_endpoint_specified(self, url: str) -> bool:
        """Whether *url* already names one of the two api endpoints."""
        # BUG FIX: the original tested self.api_url and ignored the
        # url argument entirely.
        return "/v1" in url

    def ensure_api_endpoint_specified(self):
        """Normalize ``api_url`` so it ends with an api endpoint and a slash."""
        if not self.api_endpoint_specified(self.api_url):
            # url doesn't specify the api endpoint - default to the
            # koboldcpp united api
            self.api_url = urljoin(self.api_url.rstrip("/") + "/", "/api/v1/")
        if not self.api_url.endswith("/"):
            self.api_url += "/"

    def set_client(self, **kwargs):
        self.api_key = kwargs.get("api_key", self.api_key)
        self.ensure_api_endpoint_specified()

    def tune_prompt_parameters(self, parameters: dict, kind: str):
        """Translate the common parameter names into the names the active api expects."""
        super().tune_prompt_parameters(parameters, kind)

        if not self.is_openai:
            # adjustments for united api
            parameters["max_length"] = parameters.pop("max_tokens")
            parameters["max_context_length"] = self.max_token_length
            if "repetition_penalty_range" in parameters:
                parameters["rep_pen_range"] = parameters.pop(
                    "repetition_penalty_range"
                )
            if "repetition_penalty" in parameters:
                parameters["rep_pen"] = parameters.pop("repetition_penalty")
            # BUG FIX: the original checked "stop_sequence" (which is never
            # set at this point) and then unconditionally popped
            # "stopping_strings", raising KeyError when absent.
            if parameters.get("stopping_strings"):
                parameters["stop_sequence"] = parameters.pop("stopping_strings")
            if parameters.get("extra_stopping_strings"):
                if "stop_sequence" in parameters:
                    parameters["stop_sequence"] += parameters.pop(
                        "extra_stopping_strings"
                    )
                else:
                    parameters["stop_sequence"] = parameters.pop(
                        "extra_stopping_strings"
                    )
            allowed_params = [
                "max_length",
                "max_context_length",
                "rep_pen",
                "rep_pen_range",
                "top_p",
                "top_k",
                "temperature",
                "stop_sequence",
            ]
        else:
            # adjustments for openai api
            if "repetition_penalty" in parameters:
                parameters["presence_penalty"] = parameters.pop("repetition_penalty")
            allowed_params = ["max_tokens", "presence_penalty", "top_p", "temperature"]

        # drop unsupported params
        for param in list(parameters.keys()):
            if param not in allowed_params:
                del parameters[param]

    async def get_model_name(self):
        """Query the server for the loaded model name (basename, without path)."""
        self.ensure_api_endpoint_specified()

        async with httpx.AsyncClient() as client:
            response = await client.get(
                self.api_url_for_model,
                timeout=2,
                headers=self.request_headers,
            )
            if response.status_code == 404:
                raise KeyError(
                    f"Could not find model info at: {self.api_url_for_model}"
                )
            response_data = response.json()

        if self.is_openai:
            # {"object": "list", "data": [{"id": "koboldcpp/dolphin-2.8-mistral-7b", ...}]}
            model_name = response_data.get("data")[0].get("id")
        else:
            # {"result": "koboldcpp/dolphin-2.8-mistral-7b"}
            model_name = response_data.get("result")

        # the server reports a path-like name; keep only the last segment
        if model_name:
            model_name = model_name.split("/")[-1]

        return model_name

    async def generate(self, prompt: str, parameters: dict, kind: str):
        """
        Generates text from the given prompt and parameters.
        """
        parameters["prompt"] = prompt.strip(" ")

        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.api_url_for_generation,
                json=parameters,
                timeout=None,
                headers=self.request_headers,
            )
            response_data = response.json()

        try:
            if self.is_openai:
                return response_data["choices"][0]["text"]
            else:
                return response_data["results"][0]["text"]
        except (TypeError, KeyError) as exc:
            log.error(
                "Failed to generate text",
                exc=exc,
                response_data=response_data,
                response_status=response.status_code,
            )
            return ""

    def jiggle_randomness(self, prompt_config: dict, offset: float = 0.3) -> dict:
        """
        adjusts temperature and repetition_penalty
        by random values using the base value as a center
        """
        temp = prompt_config["temperature"]
        rep_pen = prompt_config["rep_pen"]

        min_offset = offset * 0.3

        prompt_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
        prompt_config["rep_pen"] = random.uniform(
            rep_pen + min_offset * 0.3, rep_pen + offset * 0.3
        )

    def reconfigure(self, **kwargs):
        if "api_key" in kwargs:
            self.api_key = kwargs.pop("api_key")
        super().reconfigure(**kwargs)

View file

@ -11,6 +11,20 @@ class TestPromptPayload(pydantic.BaseModel):
kind: str
def ensure_number(v):
    """
    If v is a numeric string, convert it to an int or float.

    Non-string values and non-numeric strings are returned unchanged.
    """
    if not isinstance(v, str):
        return v
    # int() first so integral strings (including negatives, which the
    # original isdigit() check missed) stay integers
    try:
        return int(v)
    except ValueError:
        pass
    try:
        return float(v)
    except ValueError:
        return v
class DevToolsPlugin:
router = "devtools"
@ -34,7 +48,7 @@ class DevToolsPlugin:
log.info(
"Testing prompt",
payload={
k: v for k, v in payload.generation_parameters.items() if k != "prompt"
k: ensure_number(v) for k, v in payload.generation_parameters.items() if k != "prompt"
},
)

View file

@ -1,12 +1,12 @@
{
"name": "talemate_frontend",
"version": "0.25.1",
"version": "0.25.2",
"lockfileVersion": 2,
"requires": true,
"packages": {
"": {
"name": "talemate_frontend",
"version": "0.25.1",
"version": "0.25.2",
"dependencies": {
"@codemirror/lang-markdown": "^6.2.5",
"@codemirror/theme-one-dark": "^6.1.2",

View file

@ -1,6 +1,6 @@
{
"name": "talemate_frontend",
"version": "0.25.1",
"version": "0.25.2",
"private": true,
"scripts": {
"serve": "vue-cli-service serve",

View file

@ -244,6 +244,13 @@ export default {
client.api_key = data.api_key;
client.double_coercion = data.data.double_coercion;
client.data = data.data;
for (let key in client.data.meta.extra_fields) {
if (client.data[key] === null || client.data[key] === undefined) {
client.data[key] = client.data.meta.defaults[key];
}
client[key] = client.data[key];
}
} else if(!client) {
console.log("Adding new client", data);
@ -259,6 +266,16 @@ export default {
double_coercion: data.data.double_coercion,
data: data.data,
});
// apply extra field defaults
let client = this.state.clients[this.state.clients.length - 1];
for (let key in client.data.meta.extra_fields) {
if (client.data[key] === null || client.data[key] === undefined) {
client.data[key] = client.data.meta.defaults[key];
}
client[key] = client.data[key];
}
// sort the clients by name
this.state.clients.sort((a, b) => (a.name > b.name) ? 1 : -1);
}

View file

@ -56,9 +56,9 @@
</v-row>
<v-row v-for="field in clientMeta().extra_fields" :key="field.name">
<v-col cols="12">
<v-text-field v-model="client.data[field.name]" v-if="field.type === 'text'" :label="field.label"
<v-text-field v-model="client[field.name]" v-if="field.type === 'text'" :label="field.label"
:rules="[rules.required]" :hint="field.description"></v-text-field>
<v-checkbox v-else-if="field.type === 'bool'" v-model="client.data[field.name]"
<v-checkbox v-else-if="field.type === 'bool'" v-model="client[field.name]"
:label="field.label" :hint="field.description" density="compact"></v-checkbox>
</v-col>
</v-row>

View file

@ -248,7 +248,7 @@ export default {
messageHandlers: [],
scene: {},
appConfig: {},
autcompleting: false,
autocompleting: false,
autocompletePartialInput: "",
autocompleteCallback: null,
autocompleteFocusElement: null,

View file

@ -9,6 +9,7 @@ module.exports = defineConfig({
},
devServer: {
allowedHosts: "all",
client: {
overlay: {
warnings: false,