Prep 0.13.0 (#28)

* requirements.txt file

* windows installs from requirements.txt because of silly permission issues

* relock

* narrator - narrate on dialogue agent actions

* add support for new textgenwebui api

* world state auto regen trigger off of gameloop

* function !rename command

* ensure_dialog_format error handling

* Cat, Nous-Capybara, dolphin-2.2.1

* narrate after dialog rerun fixes, template fixes

* LMStudio client (experimental)

* dolphin yi

* refactor client base

* cruft

* openai client to new base

* more client refactor fixes

* tweak context retrieval prompts

* adjust nous capybara template

* add Tess-Medium

* 0.13.0

* switch back to poetry for windows as well

* error on legacy textgenwebui api

* runpod text gen api url fixed

* fix windows install script

* add follow instruction template

* Psyfighter2
FInalWombat 2023-11-18 12:16:29 +02:00 committed by GitHub
parent f9b23f8705
commit d7e72d27c5
37 changed files with 1315 additions and 875 deletions

View file

@ -7,10 +7,10 @@ REM activate the virtual environment
call talemate_env\Scripts\activate
REM install poetry
python -m pip install poetry "rapidfuzz>=3" -U
python -m pip install "poetry==1.7.1" "rapidfuzz>=3" -U
REM use poetry to install dependencies
poetry install
python -m poetry install
REM copy config.example.yaml to config.yaml only if config.yaml doesn't exist
IF NOT EXIST config.yaml copy config.example.yaml config.yaml

383
poetry.lock generated
View file

@ -344,17 +344,17 @@ files = [
[[package]]
name = "boto3"
version = "1.28.83"
version = "1.28.84"
description = "The AWS SDK for Python"
optional = false
python-versions = ">= 3.7"
files = [
{file = "boto3-1.28.83-py3-none-any.whl", hash = "sha256:1d10691911c4b8b9443d3060257ba32b68b6e3cad0eebbb9f69fd1c52a78417f"},
{file = "boto3-1.28.83.tar.gz", hash = "sha256:489c4967805b677b7a4030460e4c06c0903d6bc0f6834453611bf87efbd8d8a3"},
{file = "boto3-1.28.84-py3-none-any.whl", hash = "sha256:98b01bbea27740720a06f7c7bc0132ae4ce902e640aab090cfb99ad3278449c3"},
{file = "boto3-1.28.84.tar.gz", hash = "sha256:adfb915958d7b54d876891ea1599dd83189e35a2442eb41ca52b04ea716180b6"},
]
[package.dependencies]
botocore = ">=1.31.83,<1.32.0"
botocore = ">=1.31.84,<1.32.0"
jmespath = ">=0.7.1,<2.0.0"
s3transfer = ">=0.7.0,<0.8.0"
@ -363,13 +363,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
[[package]]
name = "botocore"
version = "1.31.83"
version = "1.31.84"
description = "Low-level, data-driven core of boto 3."
optional = false
python-versions = ">= 3.7"
files = [
{file = "botocore-1.31.83-py3-none-any.whl", hash = "sha256:c742069e8bfd06d212d712228258ff09fb481b6ec02358e539381ce0fcad065a"},
{file = "botocore-1.31.83.tar.gz", hash = "sha256:40914b0fb28f13d709e1f8a4481e278350b77a3987be81acd23715ec8d5fedca"},
{file = "botocore-1.31.84-py3-none-any.whl", hash = "sha256:d65bc05793d1a8a8c191a739f742876b4b403c5c713dc76beef262d18f7984a2"},
{file = "botocore-1.31.84.tar.gz", hash = "sha256:8913bedb96ad0427660dee083aeaa675466eb662bbf1a47781956b5882aadcc5"},
]
[package.dependencies]
@ -378,7 +378,18 @@ python-dateutil = ">=2.1,<3.0.0"
urllib3 = {version = ">=1.25.4,<2.1", markers = "python_version >= \"3.10\""}
[package.extras]
crt = ["awscrt (==0.16.26)"]
crt = ["awscrt (==0.19.10)"]
[[package]]
name = "cachetools"
version = "5.3.2"
description = "Extensible memoizing collections and decorators"
optional = false
python-versions = ">=3.7"
files = [
{file = "cachetools-5.3.2-py3-none-any.whl", hash = "sha256:861f35a13a451f94e301ce2bec7cac63e881232ccce7ed67fab9b5df4d3beaa1"},
{file = "cachetools-5.3.2.tar.gz", hash = "sha256:086ee420196f7b2ab9ca2db2520aca326318b68fe5ba8bc4d49cca91add450f2"},
]
[[package]]
name = "certifi"
@ -529,13 +540,13 @@ numpy = "*"
[[package]]
name = "chromadb"
version = "0.4.14"
version = "0.4.17"
description = "Chroma."
optional = false
python-versions = ">=3.7"
python-versions = ">=3.8"
files = [
{file = "chromadb-0.4.14-py3-none-any.whl", hash = "sha256:c1b59bdfb4b35a40bad0b8927c5ed757adf191ff9db2b9a384dc46a76e1ff10f"},
{file = "chromadb-0.4.14.tar.gz", hash = "sha256:0fcef603bcf9c854305020c3f8d368c09b1545d48bd2bceefd51861090f87dad"},
{file = "chromadb-0.4.17-py3-none-any.whl", hash = "sha256:8cb88162bc6124441ba5a4b93819463a10e9aaafbe05a3286e876cbdc7a7e11d"},
{file = "chromadb-0.4.17.tar.gz", hash = "sha256:120f9d364719b664d5314500f8e6097f0e0b24496bb97a429bc324f8d11f1b52"},
]
[package.dependencies]
@ -544,14 +555,20 @@ chroma-hnswlib = "0.7.3"
fastapi = ">=0.95.2"
grpcio = ">=1.58.0"
importlib-resources = "*"
kubernetes = ">=28.1.0"
numpy = {version = ">=1.22.5", markers = "python_version >= \"3.8\""}
onnxruntime = ">=1.14.1"
opentelemetry-api = ">=1.2.0"
opentelemetry-exporter-otlp-proto-grpc = ">=1.2.0"
opentelemetry-sdk = ">=1.2.0"
overrides = ">=7.3.1"
posthog = ">=2.4.0"
pulsar-client = ">=3.1.0"
pydantic = ">=1.9"
pypika = ">=0.48.9"
PyYAML = ">=6.0.0"
requests = ">=2.28"
tenacity = ">=8.2.3"
tokenizers = ">=0.13.2"
tqdm = ">=4.65.0"
typer = ">=0.9.0"
@ -600,6 +617,23 @@ humanfriendly = ">=9.1"
[package.extras]
cron = ["capturer (>=2.4)"]
[[package]]
name = "deprecated"
version = "1.2.14"
description = "Python @deprecated decorator to deprecate old python classes, functions or methods."
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
files = [
{file = "Deprecated-1.2.14-py2.py3-none-any.whl", hash = "sha256:6fac8b097794a90302bdbb17b9b815e732d3c4720583ff1b198499d78470466c"},
{file = "Deprecated-1.2.14.tar.gz", hash = "sha256:e5323eb936458dccc2582dc6f9c322c852a775a27065ff2b0c4970b9d53d01b3"},
]
[package.dependencies]
wrapt = ">=1.10,<2"
[package.extras]
dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"]
[[package]]
name = "distro"
version = "1.8.0"
@ -822,6 +856,46 @@ smb = ["smbprotocol"]
ssh = ["paramiko"]
tqdm = ["tqdm"]
[[package]]
name = "google-auth"
version = "2.23.4"
description = "Google Authentication Library"
optional = false
python-versions = ">=3.7"
files = [
{file = "google-auth-2.23.4.tar.gz", hash = "sha256:79905d6b1652187def79d491d6e23d0cbb3a21d3c7ba0dbaa9c8a01906b13ff3"},
{file = "google_auth-2.23.4-py2.py3-none-any.whl", hash = "sha256:d4bbc92fe4b8bfd2f3e8d88e5ba7085935da208ee38a134fc280e7ce682a05f2"},
]
[package.dependencies]
cachetools = ">=2.0.0,<6.0"
pyasn1-modules = ">=0.2.1"
rsa = ">=3.1.4,<5"
[package.extras]
aiohttp = ["aiohttp (>=3.6.2,<4.0.0.dev0)", "requests (>=2.20.0,<3.0.0.dev0)"]
enterprise-cert = ["cryptography (==36.0.2)", "pyopenssl (==22.0.0)"]
pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"]
reauth = ["pyu2f (>=0.1.5)"]
requests = ["requests (>=2.20.0,<3.0.0.dev0)"]
[[package]]
name = "googleapis-common-protos"
version = "1.61.0"
description = "Common protobufs used in Google APIs"
optional = false
python-versions = ">=3.7"
files = [
{file = "googleapis-common-protos-1.61.0.tar.gz", hash = "sha256:8a64866a97f6304a7179873a465d6eee97b7a24ec6cfd78e0f575e96b821240b"},
{file = "googleapis_common_protos-1.61.0-py2.py3-none-any.whl", hash = "sha256:22f1915393bb3245343f6efe87f6fe868532efc12aa26b391b15132e1279f1c0"},
]
[package.dependencies]
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0"
[package.extras]
grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"]
[[package]]
name = "grpcio"
version = "1.59.2"
@ -1050,6 +1124,25 @@ files = [
{file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"},
]
[[package]]
name = "importlib-metadata"
version = "6.8.0"
description = "Read metadata from Python packages"
optional = false
python-versions = ">=3.8"
files = [
{file = "importlib_metadata-6.8.0-py3-none-any.whl", hash = "sha256:3ebb78df84a805d7698245025b975d9d67053cd94c79245ba4b3eb694abe68bb"},
{file = "importlib_metadata-6.8.0.tar.gz", hash = "sha256:dbace7892d8c0c4ac1ad096662232f831d4e64f4c4545bd53016a3e9d4654743"},
]
[package.dependencies]
zipp = ">=0.5"
[package.extras]
docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
perf = ["ipython"]
testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"]
[[package]]
name = "importlib-resources"
version = "6.1.1"
@ -1187,6 +1280,32 @@ files = [
{file = "joblib-1.3.2.tar.gz", hash = "sha256:92f865e621e17784e7955080b6d042489e3b8e294949cc44c6eac304f59772b1"},
]
[[package]]
name = "kubernetes"
version = "28.1.0"
description = "Kubernetes python client"
optional = false
python-versions = ">=3.6"
files = [
{file = "kubernetes-28.1.0-py2.py3-none-any.whl", hash = "sha256:10f56f8160dcb73647f15fafda268e7f60cf7dbc9f8e46d52fcd46d3beb0c18d"},
{file = "kubernetes-28.1.0.tar.gz", hash = "sha256:1468069a573430fb1cb5ad22876868f57977930f80a6749405da31cd6086a7e9"},
]
[package.dependencies]
certifi = ">=14.05.14"
google-auth = ">=1.0.1"
oauthlib = ">=3.2.2"
python-dateutil = ">=2.5.3"
pyyaml = ">=5.4.1"
requests = "*"
requests-oauthlib = "*"
six = ">=1.9.0"
urllib3 = ">=1.24.2,<2.0"
websocket-client = ">=0.32.0,<0.40.0 || >0.40.0,<0.41.dev0 || >=0.43.dev0"
[package.extras]
adal = ["adal (>=1.0.2)"]
[[package]]
name = "lazy-object-proxy"
version = "1.9.0"
@ -1259,16 +1378,6 @@ files = [
{file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"},
{file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"},
{file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"},
{file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"},
{file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"},
{file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"},
{file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"},
{file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"},
{file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"},
{file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"},
{file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"},
{file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"},
{file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"},
{file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"},
{file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"},
{file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"},
@ -1551,6 +1660,22 @@ files = [
{file = "numpy-1.25.2.tar.gz", hash = "sha256:fd608e19c8d7c55021dffd43bfe5492fab8cc105cc8986f813f8c3c048b38760"},
]
[[package]]
name = "oauthlib"
version = "3.2.2"
description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic"
optional = false
python-versions = ">=3.6"
files = [
{file = "oauthlib-3.2.2-py3-none-any.whl", hash = "sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca"},
{file = "oauthlib-3.2.2.tar.gz", hash = "sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918"},
]
[package.extras]
rsa = ["cryptography (>=3.0.0)"]
signals = ["blinker (>=1.4.0)"]
signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"]
[[package]]
name = "onnxruntime"
version = "1.16.2"
@ -1614,6 +1739,101 @@ typing-extensions = ">=4.5,<5"
[package.extras]
datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
[[package]]
name = "opentelemetry-api"
version = "1.21.0"
description = "OpenTelemetry Python API"
optional = false
python-versions = ">=3.7"
files = [
{file = "opentelemetry_api-1.21.0-py3-none-any.whl", hash = "sha256:4bb86b28627b7e41098f0e93280fe4892a1abed1b79a19aec6f928f39b17dffb"},
{file = "opentelemetry_api-1.21.0.tar.gz", hash = "sha256:d6185fd5043e000075d921822fd2d26b953eba8ca21b1e2fa360dd46a7686316"},
]
[package.dependencies]
deprecated = ">=1.2.6"
importlib-metadata = ">=6.0,<7.0"
[[package]]
name = "opentelemetry-exporter-otlp-proto-common"
version = "1.21.0"
description = "OpenTelemetry Protobuf encoding"
optional = false
python-versions = ">=3.7"
files = [
{file = "opentelemetry_exporter_otlp_proto_common-1.21.0-py3-none-any.whl", hash = "sha256:97b1022b38270ec65d11fbfa348e0cd49d12006485c2321ea3b1b7037d42b6ec"},
{file = "opentelemetry_exporter_otlp_proto_common-1.21.0.tar.gz", hash = "sha256:61db274d8a68d636fb2ec2a0f281922949361cdd8236e25ff5539edf942b3226"},
]
[package.dependencies]
backoff = {version = ">=1.10.0,<3.0.0", markers = "python_version >= \"3.7\""}
opentelemetry-proto = "1.21.0"
[[package]]
name = "opentelemetry-exporter-otlp-proto-grpc"
version = "1.21.0"
description = "OpenTelemetry Collector Protobuf over gRPC Exporter"
optional = false
python-versions = ">=3.7"
files = [
{file = "opentelemetry_exporter_otlp_proto_grpc-1.21.0-py3-none-any.whl", hash = "sha256:ab37c63d6cb58d6506f76d71d07018eb1f561d83e642a8f5aa53dddf306087a4"},
{file = "opentelemetry_exporter_otlp_proto_grpc-1.21.0.tar.gz", hash = "sha256:a497c5611245a2d17d9aa1e1cbb7ab567843d53231dcc844a62cea9f0924ffa7"},
]
[package.dependencies]
backoff = {version = ">=1.10.0,<3.0.0", markers = "python_version >= \"3.7\""}
deprecated = ">=1.2.6"
googleapis-common-protos = ">=1.52,<2.0"
grpcio = ">=1.0.0,<2.0.0"
opentelemetry-api = ">=1.15,<2.0"
opentelemetry-exporter-otlp-proto-common = "1.21.0"
opentelemetry-proto = "1.21.0"
opentelemetry-sdk = ">=1.21.0,<1.22.0"
[package.extras]
test = ["pytest-grpc"]
[[package]]
name = "opentelemetry-proto"
version = "1.21.0"
description = "OpenTelemetry Python Proto"
optional = false
python-versions = ">=3.7"
files = [
{file = "opentelemetry_proto-1.21.0-py3-none-any.whl", hash = "sha256:32fc4248e83eebd80994e13963e683f25f3b443226336bb12b5b6d53638f50ba"},
{file = "opentelemetry_proto-1.21.0.tar.gz", hash = "sha256:7d5172c29ed1b525b5ecf4ebe758c7138a9224441b3cfe683d0a237c33b1941f"},
]
[package.dependencies]
protobuf = ">=3.19,<5.0"
[[package]]
name = "opentelemetry-sdk"
version = "1.21.0"
description = "OpenTelemetry Python SDK"
optional = false
python-versions = ">=3.7"
files = [
{file = "opentelemetry_sdk-1.21.0-py3-none-any.whl", hash = "sha256:9fe633243a8c655fedace3a0b89ccdfc654c0290ea2d8e839bd5db3131186f73"},
{file = "opentelemetry_sdk-1.21.0.tar.gz", hash = "sha256:3ec8cd3020328d6bc5c9991ccaf9ae820ccb6395a5648d9a95d3ec88275b8879"},
]
[package.dependencies]
opentelemetry-api = "1.21.0"
opentelemetry-semantic-conventions = "0.42b0"
typing-extensions = ">=3.7.4"
[[package]]
name = "opentelemetry-semantic-conventions"
version = "0.42b0"
description = "OpenTelemetry Semantic Conventions"
optional = false
python-versions = ">=3.7"
files = [
{file = "opentelemetry_semantic_conventions-0.42b0-py3-none-any.whl", hash = "sha256:5cd719cbfec448af658860796c5d0fcea2fdf0945a2bed2363f42cb1ee39f526"},
{file = "opentelemetry_semantic_conventions-0.42b0.tar.gz", hash = "sha256:44ae67a0a3252a05072877857e5cc1242c98d4cf12870159f1a94bec800d38ec"},
]
[[package]]
name = "orjson"
version = "3.9.10"
@ -1953,6 +2173,31 @@ files = [
{file = "py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5"},
]
[[package]]
name = "pyasn1"
version = "0.5.0"
description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7"
files = [
{file = "pyasn1-0.5.0-py2.py3-none-any.whl", hash = "sha256:87a2121042a1ac9358cabcaf1d07680ff97ee6404333bacca15f76aa8ad01a57"},
{file = "pyasn1-0.5.0.tar.gz", hash = "sha256:97b7290ca68e62a832558ec3976f15cbf911bf5d7c7039d8b861c2a0ece69fde"},
]
[[package]]
name = "pyasn1-modules"
version = "0.3.0"
description = "A collection of ASN.1-based protocols modules"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7"
files = [
{file = "pyasn1_modules-0.3.0-py2.py3-none-any.whl", hash = "sha256:d3ccd6ed470d9ffbc716be08bd90efbd44d0734bc9303818f7336070984a162d"},
{file = "pyasn1_modules-0.3.0.tar.gz", hash = "sha256:5bd01446b736eb9d31512a30d46c1ac3395d676c6f3cafa4c03eb54b9925631c"},
]
[package.dependencies]
pyasn1 = ">=0.4.6,<0.6.0"
[[package]]
name = "pydantic"
version = "2.4.2"
@ -2488,6 +2733,24 @@ urllib3 = ">=1.21.1,<3"
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "requests-oauthlib"
version = "1.3.1"
description = "OAuthlib authentication support for Requests."
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
files = [
{file = "requests-oauthlib-1.3.1.tar.gz", hash = "sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a"},
{file = "requests_oauthlib-1.3.1-py2.py3-none-any.whl", hash = "sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5"},
]
[package.dependencies]
oauthlib = ">=3.0.0"
requests = ">=2.0.0"
[package.extras]
rsa = ["oauthlib[signedtoken] (>=3.0.0)"]
[[package]]
name = "rope"
version = "0.22.0"
@ -2502,6 +2765,20 @@ files = [
[package.extras]
dev = ["build", "pytest", "pytest-timeout"]
[[package]]
name = "rsa"
version = "4.9"
description = "Pure-Python RSA implementation"
optional = false
python-versions = ">=3.6,<4"
files = [
{file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"},
{file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"},
]
[package.dependencies]
pyasn1 = ">=0.1.3"
[[package]]
name = "runpod"
version = "1.2.0"
@ -2909,6 +3186,20 @@ files = [
[package.dependencies]
mpmath = ">=0.19"
[[package]]
name = "tenacity"
version = "8.2.3"
description = "Retry code until it succeeds"
optional = false
python-versions = ">=3.7"
files = [
{file = "tenacity-8.2.3-py3-none-any.whl", hash = "sha256:ce510e327a630c9e1beaf17d42e6ffacc88185044ad85cf74c0a8887c6a0f88c"},
{file = "tenacity-8.2.3.tar.gz", hash = "sha256:5398ef0d78e63f40007c1fb4c0bff96e1911394d2fa8d194f77619c05ff6cc8a"},
]
[package.extras]
doc = ["reno", "sphinx", "tornado (>=4.5)"]
[[package]]
name = "thefuzz"
version = "0.20.0"
@ -3415,20 +3706,19 @@ files = [
[[package]]
name = "urllib3"
version = "2.0.7"
version = "1.26.18"
description = "HTTP library with thread-safe connection pooling, file post, and more."
optional = false
python-versions = ">=3.7"
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
files = [
{file = "urllib3-2.0.7-py3-none-any.whl", hash = "sha256:fdb6d215c776278489906c2f8916e6e7d4f5a9b602ccbcfdf7f016fc8da0596e"},
{file = "urllib3-2.0.7.tar.gz", hash = "sha256:c97dfde1f7bd43a71c8d2a58e369e9b2bf692d1334ea9f9cae55add7d0dd0f84"},
{file = "urllib3-1.26.18-py2.py3-none-any.whl", hash = "sha256:34b97092d7e0a3a8cf7cd10e386f401b3737364026c45e622aa02903dffe0f07"},
{file = "urllib3-1.26.18.tar.gz", hash = "sha256:f8ecc1bba5667413457c529ab955bf8c67b45db799d159066261719e328580a0"},
]
[package.extras]
brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"]
secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"]
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
zstd = ["zstandard (>=0.18.0)"]
brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"]
secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"]
socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"]
[[package]]
name = "uvicorn"
@ -3587,6 +3877,22 @@ files = [
[package.dependencies]
anyio = ">=3.0.0"
[[package]]
name = "websocket-client"
version = "1.6.4"
description = "WebSocket client for Python with low level API options"
optional = false
python-versions = ">=3.8"
files = [
{file = "websocket-client-1.6.4.tar.gz", hash = "sha256:b3324019b3c28572086c4a319f91d1dcd44e6e11cd340232978c684a7650d0df"},
{file = "websocket_client-1.6.4-py3-none-any.whl", hash = "sha256:084072e0a7f5f347ef2ac3d8698a5e0b4ffbfcab607628cadabc650fc9a83a24"},
]
[package.extras]
docs = ["Sphinx (>=6.0)", "sphinx-rtd-theme (>=1.1.0)"]
optional = ["python-socks", "wsaccel"]
test = ["websockets"]
[[package]]
name = "websockets"
version = "11.0.3"
@ -3832,7 +4138,22 @@ files = [
idna = ">=2.0"
multidict = ">=4.0"
[[package]]
name = "zipp"
version = "3.17.0"
description = "Backport of pathlib-compatible object wrapper for zip files"
optional = false
python-versions = ">=3.8"
files = [
{file = "zipp-3.17.0-py3-none-any.whl", hash = "sha256:0e923e726174922dce09c53c59ad483ff7bbb8e572e00c7f7c46b88556409f31"},
{file = "zipp-3.17.0.tar.gz", hash = "sha256:84e64a1c28cf7e91ed2078bb8cc8c259cb19b76942096c8d7b84947690cabaf0"},
]
[package.extras]
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"]
testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<4.0"
content-hash = "13dc0c939ece1591caa09211c5a29a839cb63b5a921797ab225fc723b66e0d67"
content-hash = "8d77eeb6bba3c389345f461840b5257716a397e3ecaebc735a26b06e27361a1a"

View file

@ -4,7 +4,7 @@ build-backend = "poetry.masonry.api"
[tool.poetry]
name = "talemate"
version = "0.12.0"
version = "0.13.0"
description = "AI-backed roleplay and narrative tools"
authors = ["FinalWombat"]
license = "GNU Affero General Public License v3.0"
@ -39,9 +39,9 @@ thefuzz = ">=0.20.0"
tiktoken = ">=0.5.1"
# ChromaDB
chromadb = ">=0.4,<1"
chromadb = ">=0.4.17,<1"
InstructorEmbedding = "^1.0.1"
torch = ">=2.0.0, !=2.0.1"
torch = ">=2.1.0"
sentence-transformers="^2.2.2"
[tool.poetry.dev-dependencies]

View file

@ -9,7 +9,7 @@ REM activate the virtual environment
call talemate_env\Scripts\activate
REM install poetry
python -m pip install poetry "rapidfuzz>=3" -U
python -m pip install "poetry==1.7.1" "rapidfuzz>=3" -U
REM use poetry to install dependencies
python -m poetry install

View file

@ -2,4 +2,4 @@ from .agents import Agent
from .client import TextGeneratorWebuiClient
from .tale_mate import *
VERSION = "0.12.0"
VERSION = "0.13.0"

View file

@ -328,9 +328,13 @@ class ChromaDBMemoryAgent(MemoryAgent):
model_name=instructor_model, device=instructor_device
)
log.info("chromadb", status="embedding function ready")
self.db = self.db_client.get_or_create_collection(
collection_name, embedding_function=ef
)
log.info("chromadb", status="instructor db ready")
else:
log.info("chromadb", status="using default embeddings")
self.db = self.db_client.get_or_create_collection(collection_name)
@ -461,6 +465,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
#import json
#print(json.dumps(_results["ids"], indent=2))
#print(json.dumps(_results["distances"], indent=2))
results = []

View file

@ -2,6 +2,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Callable, List, Optional, Union
import structlog
import random
import talemate.util as util
from talemate.emit import emit
import talemate.emit.async_signals
@ -9,14 +10,23 @@ from talemate.prompts import Prompt
from talemate.agents.base import set_processing, Agent, AgentAction, AgentActionConfig
from talemate.agents.world_state import TimePassageEmission
from talemate.scene_message import NarratorMessage
from talemate.events import GameLoopActorIterEvent
import talemate.client as client
from .registry import register
if TYPE_CHECKING:
from talemate.tale_mate import Actor, Player, Character
log = structlog.get_logger("talemate.agents.narrator")
@register()
class NarratorAgent(Agent):
"""
Handles narration of the story
"""
agent_type = "narrator"
verbose_name = "Narrator"
@ -27,31 +37,78 @@ class NarratorAgent(Agent):
):
self.client = client
# agent actions
self.actions = {
"narrate_time_passage": AgentAction(enabled=False, label="Narrate Time Passage", description="Whenever you indicate passage of time, narrate right after"),
"narrate_time_passage": AgentAction(enabled=True, label="Narrate Time Passage", description="Whenever you indicate passage of time, narrate right after"),
"narrate_dialogue": AgentAction(
enabled=True,
label="Narrate Dialogue",
description="Narrator will get a chance to narrate after every line of dialogue",
config = {
"ai_dialog": AgentActionConfig(
type="number",
label="AI Dialogue",
description="Chance to narrate after every line of dialogue, 1 = always, 0 = never",
value=0.3,
min=0.0,
max=1.0,
step=0.1,
),
"player_dialog": AgentActionConfig(
type="number",
label="Player Dialogue",
description="Chance to narrate after every line of dialogue, 1 = always, 0 = never",
value=0.3,
min=0.0,
max=1.0,
step=0.1,
),
}
),
}
def clean_result(self, result):
"""
Cleans the result of a narration
"""
result = result.strip().strip(":").strip()
if "#" in result:
result = result.split("#")[0]
character_names = [c.name for c in self.scene.get_characters()]
cleaned = []
for line in result.split("\n"):
if ":" in line.strip():
break
for character_name in character_names:
if line.startswith(f"{character_name}:"):
break
cleaned.append(line)
return "\n".join(cleaned)
result = "\n".join(cleaned)
#result = util.strip_partial_sentences(result)
return result
def connect(self, scene):
"""
Connect to signals
"""
super().connect(scene)
talemate.emit.async_signals.get("agent.world_state.time").connect(self.on_time_passage)
talemate.emit.async_signals.get("game_loop_actor_iter").connect(self.on_dialog)
async def on_time_passage(self, event:TimePassageEmission):
"""
Handles time passage narration, if enabled
"""
if not self.actions["narrate_time_passage"].enabled:
return
@ -60,6 +117,31 @@ class NarratorAgent(Agent):
emit("narrator", narrator_message)
self.scene.push_history(narrator_message)
async def on_dialog(self, event:GameLoopActorIterEvent):
"""
Handles dialogue narration, if enabled
"""
if not self.actions["narrate_dialogue"].enabled:
return
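# with the default config values above (0.3 for both), the narrator interjects
# after roughly 30% of AI lines and 30% of player lines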
narrate_on_ai_chance = random.random() < self.actions["narrate_dialogue"].config["ai_dialog"].value
narrate_on_player_chance = random.random() < self.actions["narrate_dialogue"].config["player_dialog"].value
log.debug("narrate on dialog", narrate_on_ai_chance=narrate_on_ai_chance, narrate_on_player_chance=narrate_on_player_chance)
if event.actor.character.is_player and not narrate_on_player_chance:
return
if not event.actor.character.is_player and not narrate_on_ai_chance:
return
response = await self.narrate_after_dialogue(event.actor.character)
narrator_message = NarratorMessage(response, source=f"narrate_dialogue:{event.actor.character.name}")
emit("narrator", narrator_message)
self.scene.push_history(narrator_message)
@set_processing
async def narrate_scene(self):
"""
@ -155,8 +237,9 @@ class NarratorAgent(Agent):
"as_narrative": as_narrative,
}
)
log.info("narrate_query", response=response)
response = self.clean_result(response.strip())
log.info("narrate_query (after clean)", response=response)
if as_narrative:
response = f"*{response}*"
@ -266,3 +349,29 @@ class NarratorAgent(Agent):
response = f"*{response}*"
return response
@set_processing
async def narrate_after_dialogue(self, character:Character):
"""
Narrate after a line of dialogue
"""
response = await Prompt.request(
"narrator.narrate-after-dialogue",
self.client,
"narrate",
vars = {
"scene": self.scene,
"max_tokens": self.client.max_token_length,
"character": character,
"last_line": str(self.scene.history[-1])
}
)
log.info("narrate_after_dialogue", response=response)
response = self.clean_result(response.strip().strip("*"))
response = f"*{response}*"
return response

View file

@ -8,6 +8,7 @@ import talemate.util as util
from talemate.prompts import Prompt
from talemate.scene_message import DirectorMessage, TimePassageMessage
from talemate.emit import emit
from talemate.events import GameLoopEvent
from .base import Agent, set_processing, AgentAction, AgentActionConfig, AgentEmission
from .registry import register
@ -16,9 +17,6 @@ import structlog
import isodate
import time
if TYPE_CHECKING:
from talemate.agents.conversation import ConversationAgentEmission
log = structlog.get_logger("talemate.agents.world_state")
@ -74,7 +72,7 @@ class WorldStateAgent(Agent):
def connect(self, scene):
super().connect(scene)
talemate.emit.async_signals.get("agent.conversation.generated").connect(self.on_conversation_generated)
talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)
async def advance_time(self, duration:str, narrative:str=None):
"""
@ -96,7 +94,7 @@ class WorldStateAgent(Agent):
)
async def on_conversation_generated(self, emission:ConversationAgentEmission):
async def on_game_loop(self, emission:GameLoopEvent):
"""
Called when a conversation is generated
"""
@ -104,8 +102,7 @@ class WorldStateAgent(Agent):
if not self.enabled:
return
for _ in emission.generation:
await self.update_world_state()
await self.update_world_state()
async def update_world_state(self):
@ -230,7 +227,7 @@ class WorldStateAgent(Agent):
):
response = await Prompt.request(
"world_state.analyze-and-follow-instruction",
"world_state.analyze-text-and-follow-instruction",
self.client,
"analyze_freeform",
vars = {

View file

@ -1,4 +1,6 @@
import os
from talemate.client.openai import OpenAIClient
from talemate.client.registry import CLIENT_CLASSES, get_client_class, register
from talemate.client.textgenwebui import TextGeneratorWebuiClient
from talemate.client.lmstudio import LMStudioClient
import talemate.client.runpod

349
src/talemate/client/base.py Normal file
View file

@ -0,0 +1,349 @@
"""
A unified client base, based on the openai API
"""
import copy
import random
import time
from typing import Callable
import structlog
import logging
from openai import AsyncOpenAI
from talemate.emit import emit
import talemate.instance as instance
import talemate.client.presets as presets
import talemate.client.system_prompts as system_prompts
import talemate.util as util
from talemate.client.context import client_context_attribute
from talemate.client.model_prompts import model_prompt
# Set up logging level for httpx to WARNING to suppress debug logs.
logging.getLogger('httpx').setLevel(logging.WARNING)
REMOTE_SERVICES = [
# TODO: runpod.py should add this to the list
".runpod.net"
]
STOPPING_STRINGS = ["<|im_end|>", "</s>"]
class ClientBase:
api_url: str
model_name: str
name:str = None
enabled: bool = True
current_status: str = None
max_token_length: int = 4096
randomizable_inference_parameters: list[str] = ["temperature"]
processing: bool = False
connected: bool = False
conversation_retries: int = 5
client_type = "base"
def __init__(
self,
api_url: str,
name = None,
**kwargs,
):
self.api_url = api_url
self.name = name or self.client_type
self.log = structlog.get_logger(f"client.{self.client_type}")
self.set_client()
def __str__(self):
return f"{self.client_type}Client[{self.api_url}][{self.model_name or ''}]"
def set_client(self):
self.client = AsyncOpenAI(base_url=self.api_url, api_key="sk-1111")
def prompt_template(self, sys_msg, prompt):
"""
Applies the appropriate prompt template for the model.
"""
if not self.model_name:
self.log.warning("prompt template not applied", reason="no model loaded")
return f"{sys_msg}\n{prompt}"
return model_prompt(self.model_name, sys_msg, prompt)
def reconfigure(self, **kwargs):
"""
Reconfigures the client.
Keyword Arguments:
- api_url: the API URL to use
- max_token_length: the max token length to use
- enabled: whether the client is enabled
"""
if "api_url" in kwargs:
self.api_url = kwargs["api_url"]
if "max_token_length" in kwargs:
self.max_token_length = kwargs["max_token_length"]
if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
def toggle_disabled_if_remote(self):
"""
If the client is targeting a remote recognized service, this
will disable the client.
"""
for service in REMOTE_SERVICES:
if service in self.api_url:
if self.enabled:
self.log.warn("remote service unreachable, disabling client", client=self.name)
self.enabled = False
return True
return False
def get_system_message(self, kind: str) -> str:
"""
Returns the appropriate system message for the given kind of generation
Arguments:
- kind: the kind of generation
"""
# TODO: make extensible
if "narrate" in kind:
return system_prompts.NARRATOR
if "story" in kind:
return system_prompts.NARRATOR
if "director" in kind:
return system_prompts.DIRECTOR
if "create" in kind:
return system_prompts.CREATOR
if "roleplay" in kind:
return system_prompts.ROLEPLAY
if "conversation" in kind:
return system_prompts.ROLEPLAY
if "editor" in kind:
return system_prompts.EDITOR
if "world_state" in kind:
return system_prompts.WORLD_STATE
if "analyst" in kind:
return system_prompts.ANALYST
if "analyze" in kind:
return system_prompts.ANALYST
return system_prompts.BASIC
def emit_status(self, processing: bool = None):
"""
Sets and emits the client status.
"""
if processing is not None:
self.processing = processing
if not self.enabled:
status = "disabled"
model_name = "Disabled"
elif not self.connected:
status = "error"
model_name = "Could not connect"
elif self.model_name:
status = "busy" if self.processing else "idle"
model_name = self.model_name
else:
model_name = "No model loaded"
status = "warning"
status_change = status != self.current_status
self.current_status = status
emit(
"client_status",
message=self.client_type,
id=self.name,
details=model_name,
status=status,
)
if status_change:
instance.emit_agent_status_by_client(self)
async def get_model_name(self):
models = await self.client.models.list()
try:
return models.data[0].id
except IndexError:
return None
async def status(self):
"""
Send a request to the API to retrieve the loaded AI model name.
Raises an error if no model name is returned.
:return: None
"""
if self.processing:
return
if not self.enabled:
self.connected = False
self.emit_status()
return
try:
self.model_name = await self.get_model_name()
except Exception as e:
self.log.warning("client status error", e=e, client=self.name)
self.model_name = None
self.connected = False
self.toggle_disabled_if_remote()
self.emit_status()
return
self.connected = True
if not self.model_name or self.model_name == "None":
self.log.warning("client model not loaded", client=self)
self.emit_status()
return
self.emit_status()
def generate_prompt_parameters(self, kind:str):
parameters = {}
self.tune_prompt_parameters(
presets.configure(parameters, kind, self.max_token_length),
kind
)
return parameters
def tune_prompt_parameters(self, parameters:dict, kind:str):
parameters["stream"] = False
if client_context_attribute("nuke_repetition") > 0.0 and self.jiggle_enabled_for(kind):
self.jiggle_randomness(parameters, offset=client_context_attribute("nuke_repetition"))
fn_tune_kind = getattr(self, f"tune_prompt_parameters_{kind}", None)
if fn_tune_kind:
fn_tune_kind(parameters)
def tune_prompt_parameters_conversation(self, parameters:dict):
conversation_context = client_context_attribute("conversation")
parameters["max_tokens"] = conversation_context.get("length", 96)
dialog_stopping_strings = [
f"{character}:" for character in conversation_context["other_characters"]
]
if "extra_stopping_strings" in parameters:
parameters["extra_stopping_strings"] += dialog_stopping_strings
else:
parameters["extra_stopping_strings"] = dialog_stopping_strings
async def generate(self, prompt:str, parameters:dict, kind:str):
"""
Generates text from the given prompt and parameters.
"""
self.log.debug("generate", prompt=prompt[:128]+" ...", parameters=parameters)
try:
response = await self.client.completions.create(prompt=prompt.strip(), **parameters)
return response.get("choices", [{}])[0].get("text", "")
except Exception as e:
self.log.error("generate error", e=e)
return ""
async def send_prompt(
self, prompt: str, kind: str = "conversation", finalize: Callable = lambda x: x
) -> str:
"""
Send a prompt to the AI and return its response.
:param prompt: The text prompt to send.
:return: The AI's response text.
"""
try:
self.emit_status(processing=True)
await self.status()
prompt_param = self.generate_prompt_parameters(kind)
finalized_prompt = self.prompt_template(self.get_system_message(kind), prompt).strip()
prompt_param = finalize(prompt_param)
token_length = self.count_tokens(finalized_prompt)
time_start = time.time()
extra_stopping_strings = prompt_param.pop("extra_stopping_strings", [])
self.log.debug("send_prompt", token_length=token_length, max_token_length=self.max_token_length, parameters=prompt_param)
response = await self.generate(finalized_prompt, prompt_param, kind)
time_end = time.time()
# stopping strings sometimes get appended to the end of the response anyways
# split the response by the first stopping string and take the first part
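# e.g. "She nods slowly.</s> extra tokens" -> "She nods slowly." (illustrative)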
for stopping_string in STOPPING_STRINGS + extra_stopping_strings:
if stopping_string in response:
response = response.split(stopping_string)[0]
break
emit("prompt_sent", data={
"kind": kind,
"prompt": finalized_prompt,
"response": response,
"prompt_tokens": token_length,
"response_tokens": self.count_tokens(response),
"time": time_end - time_start,
})
return response
finally:
self.emit_status(processing=False)
def count_tokens(self, content:str):
return util.count_tokens(content)
def jiggle_randomness(self, prompt_config:dict, offset:float=0.3) -> dict:
"""
adjusts temperature and repetition_penalty
by random values using the base value as a center
"""
temp = prompt_config["temperature"]
min_offset = offset * 0.3
prompt_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
def jiggle_enabled_for(self, kind:str):
if kind in ["conversation", "story"]:
return True
if kind.startswith("narrate"):
return True
return False

View file

@ -0,0 +1,56 @@
from talemate.client.base import ClientBase
from talemate.client.registry import register
from openai import AsyncOpenAI
@register()
class LMStudioClient(ClientBase):
client_type = "lmstudio"
conversation_retries = 5
def set_client(self):
self.client = AsyncOpenAI(base_url=self.api_url+"/v1", api_key="sk-1111")
def tune_prompt_parameters(self, parameters:dict, kind:str):
super().tune_prompt_parameters(parameters, kind)
keys = list(parameters.keys())
valid_keys = ["temperature", "top_p"]
for key in keys:
if key not in valid_keys:
del parameters[key]
async def get_model_name(self):
model_name = await super().get_model_name()
# model name comes back as a file path, so we need to extract the model name
# the path could be windows or linux so it needs to handle both backslash and forward slash
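# e.g. "C:\models\dolphin-2.2.1.Q5_K_M.gguf" -> "dolphin-2.2.1.Q5_K_M.gguf" (illustrative path)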
if model_name:
model_name = model_name.replace("\\", "/").split("/")[-1]
return model_name
async def generate(self, prompt:str, parameters:dict, kind:str):
"""
Generates text from the given prompt and parameters.
"""
human_message = {'role': 'user', 'content': prompt.strip()}
self.log.debug("generate", prompt=prompt[:128]+" ...", parameters=parameters)
try:
response = await self.client.chat.completions.create(
model=self.model_name, messages=[human_message], **parameters
)
return response.choices[0].message.content
except Exception as e:
self.log.error("generate error", e=e)
return ""

View file

@ -1,10 +1,9 @@
import asyncio
import os
import time
from typing import Callable
import json
from openai import AsyncOpenAI
from talemate.client.base import ClientBase
from talemate.client.registry import register
from talemate.emit import emit
from talemate.config import load_config
@ -15,10 +14,9 @@ import tiktoken
__all__ = [
"OpenAIClient",
]
log = structlog.get_logger("talemate")
def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613"):
def num_tokens_from_messages(messages:list[dict], model:str="gpt-3.5-turbo-0613"):
"""Return the number of tokens used by a list of messages."""
try:
encoding = tiktoken.encoding_for_model(model)
@ -70,7 +68,7 @@ def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613"):
return num_tokens
@register()
class OpenAIClient:
class OpenAIClient(ClientBase):
"""
OpenAI client for generating text.
"""
@ -79,13 +77,10 @@ class OpenAIClient:
conversation_retries = 0
def __init__(self, model="gpt-4-1106-preview", **kwargs):
self.name = kwargs.get("name", "openai")
self.model_name = model
self.last_token_length = 0
self.max_token_length = 2048
self.processing = False
self.current_status = "idle"
self.config = load_config()
super().__init__(**kwargs)
# if os.environ.get("OPENAI_API_KEY") is not set, look in the config file
# and set it
@ -94,7 +89,7 @@ class OpenAIClient:
if self.config.get("openai", {}).get("api_key"):
os.environ["OPENAI_API_KEY"] = self.config["openai"]["api_key"]
self.set_client(model)
self.set_client()
@property
@ -123,12 +118,14 @@ class OpenAIClient:
status=status,
)
def set_client(self, model:str, max_token_length:int=None):
def set_client(self, max_token_length:int=None):
if not self.openai_api_key:
log.error("No OpenAI API key set")
return
model = self.model_name
self.client = AsyncOpenAI()
if model == "gpt-3.5-turbo":
self.max_token_length = min(max_token_length or 4096, 4096)
@ -144,89 +141,72 @@ class OpenAIClient:
def reconfigure(self, **kwargs):
if "model" in kwargs:
self.model_name = kwargs["model"]
self.set_client(self.model_name, kwargs.get("max_token_length"))
self.set_client(kwargs.get("max_token_length"))
def count_tokens(self, content: str):
return num_tokens_from_messages([{"content": content}], model=self.model_name)
async def status(self):
self.emit_status()
def get_system_message(self, kind: str) -> str:
if "narrate" in kind:
return system_prompts.NARRATOR
if "story" in kind:
return system_prompts.NARRATOR
if "director" in kind:
return system_prompts.DIRECTOR
if "create" in kind:
return system_prompts.CREATOR
if "roleplay" in kind:
return system_prompts.ROLEPLAY
if "conversation" in kind:
return system_prompts.ROLEPLAY
if "editor" in kind:
return system_prompts.EDITOR
if "world_state" in kind:
return system_prompts.WORLD_STATE
if "analyst" in kind:
return system_prompts.ANALYST
if "analyze" in kind:
return system_prompts.ANALYST
return system_prompts.BASIC
async def send_prompt(
self, prompt: str, kind: str = "conversation", finalize: Callable = lambda x: x
) -> str:
right = ""
opts = {}
def prompt_template(self, system_message:str, prompt:str):
# only gpt-4-1106-preview supports json_object response coercion
supports_json_object = self.model_name in ["gpt-4-1106-preview"]
if "<|BOT|>" in prompt:
_, right = prompt.split("<|BOT|>", 1)
if right:
prompt = prompt.replace("<|BOT|>", "\nContinue this response: ")
expected_response = prompt.split("\nContinue this response: ")[1].strip()
if expected_response.startswith("{") and supports_json_object:
opts["response_format"] = {"type": "json_object"}
else:
prompt = prompt.replace("<|BOT|>", "")
self.emit_status(processing=True)
await asyncio.sleep(0.1)
return prompt
sys_message = {'role': 'system', 'content': self.get_system_message(kind)}
def tune_prompt_parameters(self, parameters:dict, kind:str):
super().tune_prompt_parameters(parameters, kind)
human_message = {'role': 'user', 'content': prompt}
keys = list(parameters.keys())
log.debug("openai send", kind=kind, sys_message=sys_message, opts=opts)
valid_keys = ["temperature", "top_p"]
time_start = time.time()
for key in keys:
if key not in valid_keys:
del parameters[key]
response = await self.client.chat.completions.create(model=self.model_name, messages=[sys_message, human_message], **opts)
async def generate(self, prompt:str, parameters:dict, kind:str):
time_end = time.time()
"""
Generates text from the given prompt and parameters.
"""
response = response.choices[0].message.content
# only gpt-4-1106-preview supports json_object response coercion
supports_json_object = self.model_name in ["gpt-4-1106-preview"]
right = None
try:
_, right = prompt.split("\nContinue this response: ")
expected_response = right.strip()
if expected_response.startswith("{") and supports_json_object:
parameters["response_format"] = {"type": "json_object"}
except IndexError:
pass
if right and response.startswith(right):
response = response[len(right):].strip()
human_message = {'role': 'user', 'content': prompt.strip()}
system_message = {'role': 'system', 'content': self.get_system_message(kind)}
if kind == "conversation":
response = response.replace("\n", " ").strip()
self.log.debug("generate", prompt=prompt[:128]+" ...", parameters=parameters)
log.debug("openai response", response=response)
try:
response = await self.client.chat.completions.create(
model=self.model_name, messages=[system_message, human_message], **parameters
)
emit("prompt_sent", data={
"kind": kind,
"prompt": prompt,
"response": response,
"prompt_tokens": num_tokens_from_messages([sys_message, human_message], model=self.model_name),
"response_tokens": num_tokens_from_messages([{"role": "assistant", "content": response}], model=self.model_name),
"time": time_end - time_start,
})
response = response.choices[0].message.content
self.emit_status(processing=False)
return response
if right and response.startswith(right):
response = response[len(right):].strip()
return response
except Exception as e:
self.log.error("generate error", e=e)
return ""

View file

@ -0,0 +1,163 @@
__all__ = [
"configure",
"set_max_tokens",
"set_preset",
"preset_for_kind",
"max_tokens_for_kind",
"PRESET_TALEMATE_CONVERSATION",
"PRESET_TALEMATE_CREATOR",
"PRESET_LLAMA_PRECISE",
"PRESET_DIVINE_INTELLECT",
"PRESET_SIMPLE_1",
]
PRESET_TALEMATE_CONVERSATION = {
"temperature": 0.65,
"top_p": 0.47,
"top_k": 42,
"repetition_penalty": 1.18,
"repetition_penalty_range": 2048,
}
PRESET_TALEMATE_CREATOR = {
"temperature": 0.7,
"top_p": 0.9,
"top_k": 20,
"repetition_penalty": 1.15,
"repetition_penalty_range": 512,
}
PRESET_LLAMA_PRECISE = {
'temperature': 0.7,
'top_p': 0.1,
'top_k': 40,
'repetition_penalty': 1.18,
}
PRESET_DIVINE_INTELLECT = {
'temperature': 1.31,
'top_p': 0.14,
'top_k': 49,
"repetition_penalty_range": 1024,
'repetition_penalty': 1.17,
}
PRESET_SIMPLE_1 = {
"temperature": 0.7,
"top_p": 0.9,
"top_k": 20,
"repetition_penalty": 1.15,
}
def configure(config:dict, kind:str, total_budget:int):
"""
Sets the config based on the kind of text to generate.
"""
set_preset(config, kind)
set_max_tokens(config, kind, total_budget)
return config
def set_max_tokens(config:dict, kind:str, total_budget:int):
"""
Sets the max_tokens in the config based on the kind of text to generate.
"""
config["max_tokens"] = max_tokens_for_kind(kind, total_budget)
return config
def set_preset(config:dict, kind:str):
"""
Sets the preset in the config based on the kind of text to generate.
"""
config.update(preset_for_kind(kind))
def preset_for_kind(kind: str):
if kind == "conversation":
return PRESET_TALEMATE_CONVERSATION
elif kind == "conversation_old":
return PRESET_TALEMATE_CONVERSATION # Assuming old conversation uses the same preset
elif kind == "conversation_long":
return PRESET_TALEMATE_CONVERSATION # Assuming long conversation uses the same preset
elif kind == "conversation_select_talking_actor":
return PRESET_TALEMATE_CONVERSATION # Assuming select talking actor uses the same preset
elif kind == "summarize":
return PRESET_LLAMA_PRECISE
elif kind == "analyze":
return PRESET_SIMPLE_1
elif kind == "analyze_creative":
return PRESET_DIVINE_INTELLECT
elif kind == "analyze_long":
return PRESET_SIMPLE_1 # Assuming long analysis uses the same preset as simple
elif kind == "analyze_freeform":
return PRESET_LLAMA_PRECISE
elif kind == "analyze_freeform_short":
return PRESET_LLAMA_PRECISE # Assuming short freeform analysis uses the same preset as precise
elif kind == "narrate":
return PRESET_LLAMA_PRECISE
elif kind == "story":
return PRESET_DIVINE_INTELLECT
elif kind == "create":
return PRESET_TALEMATE_CREATOR
elif kind == "create_concise":
return PRESET_TALEMATE_CREATOR # Assuming concise creation uses the same preset as creator
elif kind == "create_precise":
return PRESET_LLAMA_PRECISE
elif kind == "director":
return PRESET_SIMPLE_1
elif kind == "director_short":
return PRESET_SIMPLE_1 # Assuming short direction uses the same preset as simple
elif kind == "director_yesno":
return PRESET_SIMPLE_1 # Assuming yes/no direction uses the same preset as simple
elif kind == "edit_dialogue":
return PRESET_DIVINE_INTELLECT
elif kind == "edit_add_detail":
return PRESET_DIVINE_INTELLECT # Assuming adding detail uses the same preset as divine intellect
elif kind == "edit_fix_exposition":
return PRESET_DIVINE_INTELLECT # Assuming fixing exposition uses the same preset as divine intellect
else:
return PRESET_SIMPLE_1 # Default preset if none of the kinds match
def max_tokens_for_kind(kind: str, total_budget: int):
if kind == "conversation":
return 75 # Example value, adjust as needed
elif kind == "conversation_old":
return 75 # Example value, adjust as needed
elif kind == "conversation_long":
return 300 # Example value, adjust as needed
elif kind == "conversation_select_talking_actor":
return 30 # Example value, adjust as needed
elif kind == "summarize":
return 500 # Example value, adjust as needed
elif kind == "analyze":
return 500 # Example value, adjust as needed
elif kind == "analyze_creative":
return 1024 # Example value, adjust as needed
elif kind == "analyze_long":
return 2048 # Example value, adjust as needed
elif kind == "analyze_freeform":
return 500 # Example value, adjust as needed
elif kind == "analyze_freeform_short":
return 10 # Example value, adjust as needed
elif kind == "narrate":
return 500 # Example value, adjust as needed
elif kind == "story":
return 300 # Example value, adjust as needed
elif kind == "create":
return min(1024, int(total_budget * 0.35)) # Example calculation, adjust as needed
elif kind == "create_concise":
return min(400, int(total_budget * 0.25)) # Example calculation, adjust as needed
elif kind == "create_precise":
return min(400, int(total_budget * 0.25)) # Example calculation, adjust as needed
elif kind == "director":
return min(600, int(total_budget * 0.25)) # Example calculation, adjust as needed
elif kind == "director_short":
return 25 # Example value, adjust as needed
elif kind == "director_yesno":
return 2 # Example value, adjust as needed
elif kind == "edit_dialogue":
return 100 # Example value, adjust as needed
elif kind == "edit_add_detail":
return 200 # Example value, adjust as needed
elif kind == "edit_fix_exposition":
return 1024 # Example value, adjust as needed
else:
return 150 # Default value if none of the kinds match
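A minimal sketch of how the new presets module above resolves a request kind (hypothetical call, not part of this commit; the import path matches the one used by the new client base):
import talemate.client.presets as presets
# "narrate" resolves to PRESET_LLAMA_PRECISE plus a 500 token budget
config = presets.configure({}, "narrate", 4096)
# config == {"temperature": 0.7, "top_p": 0.1, "top_k": 40,
#            "repetition_penalty": 1.18, "max_tokens": 500}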

View file

@ -67,9 +67,9 @@ def _client_bootstrap(client_type: ClientType, pod):
id = pod["id"]
if client_type == ClientType.textgen:
api_url = f"https://{id}-5000.proxy.runpod.net/api"
api_url = f"https://{id}-5000.proxy.runpod.net"
elif client_type == ClientType.automatic1111:
api_url = f"https://{id}-5000.proxy.runpod.net/api"
api_url = f"https://{id}-5000.proxy.runpod.net"
return ClientBootstrap(
client_type=client_type,

View file

@ -1,735 +1,61 @@
import asyncio
import random
import json
import copy
import structlog
import time
import httpx
from abc import ABC, abstractmethod
from typing import Callable, Union
import logging
import talemate.util as util
from talemate.client.base import ClientBase, STOPPING_STRINGS
from talemate.client.registry import register
import talemate.client.system_prompts as system_prompts
from talemate.emit import Emission, emit
from talemate.client.context import client_context_attribute
from talemate.client.model_prompts import model_prompt
import talemate.instance as instance
log = structlog.get_logger(__name__)
__all__ = [
"TaleMateClient",
"RestApiTaleMateClient",
"TextGeneratorWebuiClient",
]
# Set up logging level for httpx to WARNING to suppress debug logs.
logging.getLogger('httpx').setLevel(logging.WARNING)
class DefaultContext(int):
pass
PRESET_TALEMATE_LEGACY = {
"temperature": 0.72,
"top_p": 0.73,
"top_k": 0,
"top_a": 0,
"repetition_penalty": 1.18,
"repetition_penalty_range": 2048,
"encoder_repetition_penalty": 1,
#"encoder_repetition_penalty": 1.2,
#"no_repeat_ngram_size": 2,
"do_sample": True,
"length_penalty": 1,
}
PRESET_TALEMATE_CONVERSATION = {
"temperature": 0.65,
"top_p": 0.47,
"top_k": 42,
"typical_p": 1,
"top_a": 0,
"tfs": 1,
"epsilon_cutoff": 0,
"eta_cutoff": 0,
"repetition_penalty": 1.18,
"repetition_penalty_range": 2048,
"no_repeat_ngram_size": 0,
"penalty_alpha": 0,
"num_beams": 1,
"length_penalty": 1,
"min_length": 0,
"encoder_rep_pen": 1,
"do_sample": True,
"early_stopping": False,
"mirostat_mode": 0,
"mirostat_tau": 5,
"mirostat_eta": 0.1
}
PRESET_TALEMATE_CREATOR = {
"temperature": 0.7,
"top_p": 0.9,
"repetition_penalty": 1.15,
"repetition_penalty_range": 512,
"top_k": 20,
"do_sample": True,
"length_penalty": 1,
}
PRESET_LLAMA_PRECISE = {
'temperature': 0.7,
'top_p': 0.1,
'repetition_penalty': 1.18,
'top_k': 40
}
PRESET_KOBOLD_GODLIKE = {
'temperature': 0.7,
'top_p': 0.5,
'typical_p': 0.19,
'repetition_penalty': 1.1,
"repetition_penalty_range": 1024,
}
PRESET_DIVINE_INTELLECT = {
'temperature': 1.31,
'top_p': 0.14,
"repetition_penalty_range": 1024,
'repetition_penalty': 1.17,
'top_k': 49,
"mirostat_mode": 0,
"mirostat_tau": 5,
"mirostat_eta": 0.1,
"tfs": 1,
}
PRESET_SIMPLE_1 = {
"temperature": 0.7,
"top_p": 0.9,
"repetition_penalty": 1.15,
"top_k": 20,
}
def jiggle_randomness(prompt_config:dict, offset:float=0.3) -> dict:
"""
adjusts temperature and repetition_penalty
by random values using the base value as a center
"""
temp = prompt_config["temperature"]
rep_pen = prompt_config["repetition_penalty"]
copied_config = copy.deepcopy(prompt_config)
min_offset = offset * 0.3
copied_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
copied_config["repetition_penalty"] = random.uniform(rep_pen + min_offset * 0.3, rep_pen + offset * 0.3)
return copied_config
class TaleMateClient:
"""
An abstract TaleMate client that can be implemented for different communication methods with the AI.
"""
def __init__(
self,
api_url: str,
max_token_length: Union[int, DefaultContext] = int.__new__(
DefaultContext, 2048
),
):
self.api_url = api_url
self.name = "generic_client"
self.model_name = None
self.last_token_length = 0
self.max_token_length = max_token_length
self.original_max_token_length = max_token_length
self.enabled = True
self.current_status = None
@abstractmethod
def send_message(self, message: dict) -> str:
"""
Sends a message to the AI. Needs to be implemented by the subclass.
:param message: The message to be sent.
:return: The AI's response text.
"""
pass
@abstractmethod
def send_prompt(self, prompt: str) -> str:
"""
Sends a prompt to the AI. Needs to be implemented by the subclass.
:param prompt: The text prompt to send.
:return: The AI's response text.
"""
pass
def reconfigure(self, **kwargs):
if "api_url" in kwargs:
self.api_url = kwargs["api_url"]
if "max_token_length" in kwargs:
self.max_token_length = kwargs["max_token_length"]
if "enabled" in kwargs:
self.enabled = bool(kwargs["enabled"])
def remaining_tokens(self, context: Union[str, list]) -> int:
return self.max_token_length - util.count_tokens(context)
def prompt_template(self, sys_msg, prompt):
return model_prompt(self.model_name, sys_msg, prompt)
class RESTTaleMateClient(TaleMateClient, ABC):
"""
A RESTful TaleMate client that connects to the REST API endpoint.
"""
async def send_message(self, message: dict, url: str) -> str:
"""
Sends a message to the REST API and returns the AI's response.
:param message: The message to be sent.
:return: The AI's response text.
"""
try:
async with httpx.AsyncClient() as client:
response = await client.post(url, json=message, timeout=None)
response_data = response.json()
return response_data["results"][0]["text"]
except KeyError:
return response_data["results"][0]["history"]["visible"][0][-1]
from openai import AsyncOpenAI
import httpx
import copy
import random
@register()
class TextGeneratorWebuiClient(RESTTaleMateClient):
    """
    Client that connects to the text-generation-webui API
    """
    def __init__(self, api_url: str, max_token_length: int = 2048, **kwargs):
        api_url = self.cleanup_api_url(api_url)
        self.api_url_base = api_url
        api_url = f"{api_url}/v1/chat"
        super().__init__(api_url, max_token_length=max_token_length)
        self.model_name = None
        self.limited_ram = False
        self.name = kwargs.get("name", "textgenwebui")
        self.processing = False
        self.connected = False
class TextGeneratorWebuiClient(ClientBase):
    client_type = "textgenwebui"
    conversation_retries = 5
    def tune_prompt_parameters(self, parameters:dict, kind:str):
        super().tune_prompt_parameters(parameters, kind)
        parameters["stopping_strings"] = STOPPING_STRINGS + parameters.get("extra_stopping_strings", [])
        # is this needed?
        parameters["max_new_tokens"] = parameters["max_tokens"]
    def set_client(self):
        self.client = AsyncOpenAI(base_url=self.api_url+"/v1", api_key="sk-1111")
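The refactored client talks to text-generation-webui through its OpenAI-compatible API, as set_client does above. A minimal standalone sketch of that connection (the base URL is illustrative, the dummy api_key mirrors the placeholder above, and the local server ignores the model field):

from openai import AsyncOpenAI
import asyncio

async def probe():
    client = AsyncOpenAI(base_url="http://localhost:5000/v1", api_key="sk-1111")
    completion = await client.completions.create(
        model="_", prompt="Once upon a time", max_tokens=32
    )
    print(completion.choices[0].text)

asyncio.run(probe())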
async def get_model_name(self):
async with httpx.AsyncClient() as client:
response = await client.get(f"{self.api_url}/v1/internal/model/info", timeout=2)
if response.status_code == 404:
raise Exception("Could not find model info (wrong api version?)")
response_data = response.json()
model_name = response_data.get("model_name")
return model_name
def __str__(self):
return f"TextGeneratorWebuiClient[{self.api_url_base}][{self.model_name or ''}]"
def cleanup_api_url(self, api_url:str):
    """
    Strips trailing / and ensures endpoint is /api
    """
    if api_url.endswith("/"):
        api_url = api_url[:-1]
    if not api_url.endswith("/api"):
        api_url = api_url + "/api"
    return api_url
async def generate(self, prompt:str, parameters:dict, kind:str):
    """
    Generates text from the given prompt and parameters.
    """
    headers = {}
    headers["Content-Type"] = "application/json"
    parameters["prompt"] = prompt.strip()
def reconfigure(self, **kwargs):
super().reconfigure(**kwargs)
if "api_url" in kwargs:
log.debug("reconfigure", api_url=kwargs["api_url"])
api_url = kwargs["api_url"]
api_url = self.cleanup_api_url(api_url)
self.api_url_base = api_url
self.api_url = api_url
def toggle_disabled_if_remote(self):
remote_servies = [
".runpod.net"
]
for service in remote_servies:
if service in self.api_url_base:
self.enabled = False
return
def emit_status(self, processing: bool = None):
if processing is not None:
self.processing = processing
if not self.enabled:
status = "disabled"
model_name = "Disabled"
elif not self.connected:
status = "error"
model_name = "Could not connect"
elif self.model_name:
status = "busy" if self.processing else "idle"
model_name = self.model_name
else:
model_name = "No model loaded"
status = "warning"
status_change = status != self.current_status
self.current_status = status
emit(
"client_status",
message=self.client_type,
id=self.name,
details=model_name,
status=status,
)
if status_change:
instance.emit_agent_status_by_client(self)
# Add the 'status' method
async def status(self):
"""
Send a request to the API to retrieve the loaded AI model name.
Raises an error if no model name is returned.
:return: None
"""
if not self.enabled:
self.connected = False
self.emit_status()
return
try:
async with httpx.AsyncClient() as client:
response = await client.get(f"{self.api_url_base}/v1/model", timeout=2)
except (
httpx.TimeoutException,
httpx.NetworkError,
):
self.model_name = None
self.connected = False
self.toggle_disabled_if_remote()
self.emit_status()
return
self.connected = True
try:
    response_data = response.json()
    self.enabled = True
except json.decoder.JSONDecodeError as e:
    self.connected = False
    self.toggle_disabled_if_remote()
    if not self.enabled:
        log.warn("remote service unreachable, disabling client", name=self.name)
    else:
        log.error("client response error", name=self.name, e=e)
    self.emit_status()
    return
model_name = response_data.get("result")
if not model_name or model_name == "None":
    log.warning("client model not loaded", client=self.name)
    self.emit_status()
    return
model_changed = model_name != self.model_name
self.model_name = model_name
if model_changed:
    self.auto_context_length()
log.info(f"{self} [{self.max_token_length} ctx]: ready")
self.emit_status()
async with httpx.AsyncClient() as client:
    response = await client.post(f"{self.api_url}/v1/completions", json=parameters, timeout=None, headers=headers)
    response_data = response.json()
    return response_data["choices"][0]["text"]
def jiggle_randomness(self, prompt_config:dict, offset:float=0.3) -> dict:
    """
    adjusts temperature and repetition_penalty
    by random values using the base value as a center
    """
    temp = prompt_config["temperature"]
    rep_pen = prompt_config["repetition_penalty"]
    min_offset = offset * 0.3
def auto_context_length(self):
    """
    Automatically sets context length based on the LLM
    """
    if not isinstance(self.max_token_length, DefaultContext):
        # context length was specified manually
        return
    model_name = self.model_name.lower()
    if "longchat" in model_name:
        self.max_token_length = 16000
    elif "8k" in model_name:
        if not self.limited_ram or "13b" in model_name:
            self.max_token_length = 6000
        else:
            self.max_token_length = 4096
    elif "4k" in model_name:
        self.max_token_length = 4096
    else:
        self.max_token_length = self.original_max_token_length
@property
def instruction_template(self):
if "vicuna" in self.model_name.lower():
return "Vicuna-v1.1"
if "camel" in self.model_name.lower():
return "Vicuna-v1.1"
return ""
def prompt_url(self):
return self.api_url_base + "/v1/generate"
def prompt_config_conversation_old(self, prompt: str) -> dict:
prompt = self.prompt_template(
system_prompts.BASIC,
prompt,
)
config = {
"prompt": prompt,
"max_new_tokens": 75,
"truncation_length": self.max_token_length,
}
config.update(PRESET_TALEMATE_CONVERSATION)
return config
def prompt_config_conversation(self, prompt: str) -> dict:
prompt = self.prompt_template(
system_prompts.ROLEPLAY,
prompt,
)
stopping_strings = ["<|end_of_turn|>"]
conversation_context = client_context_attribute("conversation")
stopping_strings += [
f"{character}:" for character in conversation_context["other_characters"]
]
max_new_tokens = conversation_context.get("length", 96)
log.debug("prompt_config_conversation", stopping_strings=stopping_strings, conversation_context=conversation_context, max_new_tokens=max_new_tokens)
config = {
"prompt": prompt,
"max_new_tokens": max_new_tokens,
"truncation_length": self.max_token_length,
"stopping_strings": stopping_strings,
}
config.update(PRESET_TALEMATE_CONVERSATION)
jiggle_randomness(config)
return config
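For clarity, the merged conversation config ends up roughly in this shape (values illustrative; the character names are made up):

example_config = {
    "prompt": "<formatted roleplay prompt>",
    "max_new_tokens": 96,                     # from the conversation context, default 96
    "truncation_length": 4096,                # self.max_token_length
    "stopping_strings": ["<|end_of_turn|>", "Alice:", "Bob:"],
    # ...plus everything from PRESET_TALEMATE_CONVERSATION (temperature 0.65, top_p 0.47, ...)
}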
def prompt_config_conversation_long(self, prompt: str) -> dict:
config = self.prompt_config_conversation(prompt)
config["max_new_tokens"] = 300
return config
def prompt_config_conversation_select_talking_actor(self, prompt: str) -> dict:
config = self.prompt_config_conversation(prompt)
config["max_new_tokens"] = 30
config["stopping_strings"] += [":"]
return config
def prompt_config_summarize(self, prompt: str) -> dict:
prompt = self.prompt_template(
system_prompts.NARRATOR,
prompt,
)
config = {
"prompt": prompt,
"max_new_tokens": 500,
"truncation_length": self.max_token_length,
}
config.update(PRESET_LLAMA_PRECISE)
return config
def prompt_config_analyze(self, prompt: str) -> dict:
prompt = self.prompt_template(
system_prompts.ANALYST,
prompt,
)
config = {
"prompt": prompt,
"max_new_tokens": 500,
"truncation_length": self.max_token_length,
}
config.update(PRESET_SIMPLE_1)
return config
def prompt_config_analyze_creative(self, prompt: str) -> dict:
prompt = self.prompt_template(
system_prompts.ANALYST,
prompt,
)
config = {}
config.update(PRESET_DIVINE_INTELLECT)
config.update({
"prompt": prompt,
"max_new_tokens": 1024,
"repetition_penalty_range": 1024,
"truncation_length": self.max_token_length
})
return config
def prompt_config_analyze_long(self, prompt: str) -> dict:
config = self.prompt_config_analyze(prompt)
config["max_new_tokens"] = 2048
return config
def prompt_config_analyze_freeform(self, prompt: str) -> dict:
prompt = self.prompt_template(
system_prompts.ANALYST_FREEFORM,
prompt,
)
config = {
"prompt": prompt,
"max_new_tokens": 500,
"truncation_length": self.max_token_length,
}
config.update(PRESET_LLAMA_PRECISE)
return config
def prompt_config_analyze_freeform_short(self, prompt: str) -> dict:
config = self.prompt_config_analyze_freeform(prompt)
config["max_new_tokens"] = 10
return config
def prompt_config_narrate(self, prompt: str) -> dict:
prompt = self.prompt_template(
system_prompts.NARRATOR,
prompt,
)
config = {
"prompt": prompt,
"max_new_tokens": 500,
"truncation_length": self.max_token_length,
}
config.update(PRESET_LLAMA_PRECISE)
return config
def prompt_config_story(self, prompt: str) -> dict:
prompt = self.prompt_template(
system_prompts.NARRATOR,
prompt,
)
config = {
"prompt": prompt,
"max_new_tokens": 300,
"seed": random.randint(0, 1000000000),
"truncation_length": self.max_token_length
}
config.update(PRESET_DIVINE_INTELLECT)
config.update({
"repetition_penalty": 1.3,
"repetition_penalty_range": 2048,
})
return config
def prompt_config_create(self, prompt: str) -> dict:
prompt = self.prompt_template(
system_prompts.CREATOR,
prompt,
)
config = {
"prompt": prompt,
"max_new_tokens": min(1024, self.max_token_length * 0.35),
"truncation_length": self.max_token_length,
}
config.update(PRESET_TALEMATE_CREATOR)
return config
def prompt_config_create_concise(self, prompt: str) -> dict:
prompt = self.prompt_template(
system_prompts.CREATOR,
prompt,
)
config = {
"prompt": prompt,
"max_new_tokens": min(400, self.max_token_length * 0.25),
"truncation_length": self.max_token_length,
"stopping_strings": ["<|DONE|>", "\n\n"]
}
config.update(PRESET_TALEMATE_CREATOR)
return config
def prompt_config_create_precise(self, prompt: str) -> dict:
config = self.prompt_config_create_concise(prompt)
config.update(PRESET_LLAMA_PRECISE)
return config
def prompt_config_director(self, prompt: str) -> dict:
prompt = self.prompt_template(
system_prompts.DIRECTOR,
prompt,
)
config = {
"prompt": prompt,
"max_new_tokens": min(600, self.max_token_length * 0.25),
"truncation_length": self.max_token_length,
}
config.update(PRESET_SIMPLE_1)
return config
def prompt_config_director_short(self, prompt: str) -> dict:
config = self.prompt_config_director(prompt)
config.update(max_new_tokens=25)
return config
def prompt_config_director_yesno(self, prompt: str) -> dict:
config = self.prompt_config_director(prompt)
config.update(max_new_tokens=2)
return config
def prompt_config_edit_dialogue(self, prompt:str) -> dict:
prompt = self.prompt_template(
system_prompts.EDITOR,
prompt,
)
conversation_context = client_context_attribute("conversation")
stopping_strings = [
f"{character}:" for character in conversation_context["other_characters"]
]
config = {
"prompt": prompt,
"max_new_tokens": 100,
"truncation_length": self.max_token_length,
"stopping_strings": stopping_strings,
}
config.update(PRESET_DIVINE_INTELLECT)
return config
def prompt_config_edit_add_detail(self, prompt:str) -> dict:
config = self.prompt_config_edit_dialogue(prompt)
config.update(max_new_tokens=200)
return config
def prompt_config_edit_fix_exposition(self, prompt:str) -> dict:
config = self.prompt_config_edit_dialogue(prompt)
config.update(max_new_tokens=1024)
return config
async def send_prompt(
self, prompt: str, kind: str = "conversation", finalize: Callable = lambda x: x
) -> str:
"""
Send a prompt to the AI and return its response.
:param prompt: The text prompt to send.
:return: The AI's response text.
"""
#prompt = prompt.replace("<|BOT|>", "<|BOT|>Certainly! ")
await self.status()
self.emit_status(processing=True)
await asyncio.sleep(0.01)
fn_prompt_config = getattr(self, f"prompt_config_{kind}")
fn_url = self.prompt_url
message = fn_prompt_config(prompt)
if client_context_attribute("nuke_repetition") > 0.0 and kind in ["conversation", "story"]:
log.info("nuke repetition", offset=client_context_attribute("nuke_repetition"), temperature=message["temperature"], repetition_penalty=message["repetition_penalty"])
message = jiggle_randomness(message, offset=client_context_attribute("nuke_repetition"))
log.info("nuke repetition (applied)", offset=client_context_attribute("nuke_repetition"), temperature=message["temperature"], repetition_penalty=message["repetition_penalty"])
message = finalize(message)
token_length = int(len(message["prompt"]) / 3.6)
self.last_token_length = token_length
log.debug("send_prompt", token_length=token_length, max_token_length=self.max_token_length)
message["prompt"] = message["prompt"].strip()
#print(f"prompt: |{message['prompt']}|")
# add <|im_end|> to stopping strings
if "stopping_strings" in message:
message["stopping_strings"] += ["<|im_end|>", "</s>"]
else:
message["stopping_strings"] = ["<|im_end|>", "</s>"]
#message["seed"] = -1
#for k,v in message.items():
# if k == "prompt":
# continue
# print(f"{k}: {v}")
time_start = time.time()
response = await self.send_message(message, fn_url())
time_end = time.time()
response = response.split("#")[0]
self.emit_status(processing=False)
emit("prompt_sent", data={
"kind": kind,
"prompt": message["prompt"],
"response": response,
"prompt_tokens": token_length,
"response_tokens": int(len(response) / 3.6),
"time": time_end - time_start,
})
return response
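A minimal usage sketch of send_prompt against the legacy client defined in this file (the URL, context length, and prompt text are illustrative):

import asyncio

async def main():
    client = TextGeneratorWebuiClient(api_url="http://localhost:5000", max_token_length=4096)
    # kind selects the prompt_config_* method, e.g. "narrate" -> prompt_config_narrate
    response = await client.send_prompt("Describe the harbor at dusk.", kind="narrate")
    print(response)

asyncio.run(main())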
class OpenAPIClient(RESTTaleMateClient):
pass
class GPT3Client(OpenAPIClient):
pass
class GPT4Client(OpenAPIClient):
pass
prompt_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
prompt_config["repetition_penalty"] = random.uniform(rep_pen + min_offset * 0.3, rep_pen + offset * 0.3)

View file

@@ -0,0 +1,32 @@
import copy
import random
def jiggle_randomness(prompt_config:dict, offset:float=0.3) -> dict:
"""
adjusts temperature and repetition_penalty
by random values using the base value as a center
"""
temp = prompt_config["temperature"]
rep_pen = prompt_config["repetition_penalty"]
copied_config = copy.deepcopy(prompt_config)
min_offset = offset * 0.3
copied_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
copied_config["repetition_penalty"] = random.uniform(rep_pen + min_offset * 0.3, rep_pen + offset * 0.3)
return copied_config
def jiggle_enabled_for(kind:str):
if kind in ["conversation", "story"]:
return True
if kind.startswith("narrate"):
return True
return False
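A short sketch of how the two helpers in this new module combine with the nuke_repetition offset (the kind and values are illustrative):

def maybe_jiggle(config: dict, kind: str, offset: float) -> dict:
    # only conversation/story/narrate* prompts get their sampling randomized
    if offset > 0.0 and jiggle_enabled_for(kind):
        return jiggle_randomness(config, offset=offset)
    return config

tweaked = maybe_jiggle({"temperature": 0.65, "repetition_penalty": 1.18}, "conversation", 0.2)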

View file

@@ -17,7 +17,26 @@ class CmdRename(TalemateCommand):
aliases = []
async def run(self):
# collect list of characters in the scene
if self.args:
character_name = self.args[0]
else:
character_names = self.scene.character_names
character_name = await wait_for_input("Which character do you want to rename?", data={
"input_type": "select",
"choices": character_names,
})
character = self.scene.get_character(character_name)
if not character:
self.system_message(f"Character {character_name} not found")
return True
name = await wait_for_input("Enter new name: ")
self.scene.main_character.character.rename(name)
character.rename(name)
await asyncio.sleep(0)
return True

View file

@@ -4,7 +4,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from talemate.tale_mate import Scene
from talemate.tale_mate import Scene, Actor
__all__ = [
"Event",
@@ -43,3 +43,7 @@ class GameLoopEvent(Event):
@dataclass
class GameLoopStartEvent(GameLoopEvent):
pass
@dataclass
class GameLoopActorIterEvent(GameLoopEvent):
actor: Actor
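This new event is what lets the narrator react after each actor's turn. A hedged sketch of a subscriber, assuming the signal is consumed the same way as the existing game_loop hooks (the import path and connect() call are assumptions, not shown in this diff):

from talemate.emit import async_signals

async def on_actor_iter(event):
    # called once per actor per loop iteration; event.actor is the actor
    # that just finished its turn
    print("actor finished turn:", event.actor)

async_signals.get("game_loop_actor_iter").connect(on_actor_iter)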

View file

@@ -290,6 +290,7 @@ class Prompt:
env.globals["query_scene"] = self.query_scene
env.globals["query_memory"] = self.query_memory
env.globals["query_text"] = self.query_text
env.globals["instruct_text"] = self.instruct_text
env.globals["retrieve_memories"] = self.retrieve_memories
env.globals["uuidgen"] = lambda: str(uuid.uuid4())
env.globals["to_int"] = lambda x: int(x)
@@ -394,9 +395,14 @@ class Prompt:
f"Answer: " + loop.run_until_complete(memory.query(query, **kwargs)),
])
else:
return loop.run_until_complete(memory.multi_query([query], **kwargs))
return loop.run_until_complete(memory.multi_query(query.split("\n"), **kwargs))
def instruct_text(self, instruction:str, text:str):
loop = asyncio.get_event_loop()
world_state = instance.get_agent("world_state")
instruction = instruction.format(**self.vars)
return loop.run_until_complete(world_state.analyze_and_follow_instruction(text, instruction))
def retrieve_memories(self, lines:list[str], goal:str=None):

View file

@@ -0,0 +1,19 @@
{% block rendered_context -%}
<|SECTION:CONTEXT|>
Content Context: This is a specific scene from {{ scene.context }}
Scenario Premise: {{ scene.description }}
{% for memory in query_memory(last_line, as_question_answer=False, iterate=10) -%}
{{ memory }}
{% endfor %}
{% endblock -%}
<|CLOSE_SECTION|>
{% for scene_context in scene.context_history(budget=max_tokens-200-count_tokens(self.rendered_context())) -%}
{{ scene_context }}
{% endfor %}
<|SECTION:TASK|>
Based on the previous line '{{ last_line }}', create the next line of narration. This line should focus solely on describing sensory details (like sounds, sights, smells, tactile sensations) or external actions that move the story forward. Avoid including any character's internal thoughts, feelings, or dialogue. Your narration should directly respond to '{{ last_line }}', either by elaborating on the immediate scene or by subtly advancing the plot. Generate exactly one sentence of new narration. If the character is trying to determine some state, truth or situation, try to answer as part of the narration.
Be creative and generate something new and interesting.
<|CLOSE_SECTION|>
{{ set_prepared_response('*') }}

View file

@@ -8,13 +8,13 @@
{% if query.endswith("?") -%}
Question: {{ query }}
Extra context: {{ query_memory(query, as_question_answer=False) }}
Instruction: Analyze Context, History and Dialogue. Be factual and truthful. When evaluating both story and memory, story is more important. You can fill in gaps using imagination as long as it is based on the existing context. Respect the scene progression and answer in the context of the end of the dialogue.
Instruction: Analyze Context, History and Dialogue. When evaluating both story and memory, story is more important. You can fill in gaps using imagination as long as it is based on the existing context. Respect the scene progression and answer in the context of the end of the dialogue.
{% else -%}
Instruction: {{ query }}
Extra context: {{ query_memory(query, as_question_answer=False) }}
Answer based on Context, History and Dialogue. Be factual and truthful. When evaluating both story and memory, story is more important. You can fill in gaps using imagination as long as it is based on the existing context.
Answer based on Context, History and Dialogue. When evaluating both story and memory, story is more important. You can fill in gaps using imagination as long as it is based on the existing context.
{% endif -%}
Content Context: This is a specific scene from {{ scene.context }}
Narration style: point and click adventure game from the 90s
Your answer should be in the style of short narration that fits the context of the scene.
<|CLOSE_SECTION|>
Narrator answers: {% if at_the_end %}{{ bot_token }}At the end of the dialogue, {% endif %}

View file

@@ -8,9 +8,10 @@
<|SECTION:TASK|>
Answer the following questions:
{{ query_text("What are 1 to 3 questions to ask the narrator of the story to gather more context from the past for the continuation of this conversation? If a character is asking about a status, location or information about an item or another character, make sure to include question(s) that help gather context for this. Don't explain your reasoning. Don't ask the actors directly.", text, as_question_answer=False) }}
{{ instruct_text("Ask the narrator three (3) questions to gather more context from the past for the continuation of this conversation. If a character is asking about a state, location or information about an item or another character, make sure to include question(s) that help gather context for this.", text) }}
Your answers should be precise, truthful and short.
Your answers should be precise, truthful and short. Pay close attention to timestamps when retrieving information from the context.
<|CLOSE_SECTION|>
<|SECTION:RELEVANT CONTEXT|>
{{ bot_token }}Answers:

View file

@@ -0,0 +1,5 @@
{{ text }}
<|SECTION:TASK|>
{{ instruction }}

View file

@@ -34,7 +34,7 @@ No dialogue so far
{% endif -%}
<|CLOSE_SECTION|>
<|SECTION:SCENE PROGRESS|>
{% for scene_context in scene.context_history(budget=300, min_dialogue=5, add_archieved_history=False, max_dialogue=5) -%}
{% for scene_context in scene.context_history(budget=500, min_dialogue=5, add_archieved_history=False, max_dialogue=5) -%}
{{ scene_context }}
{% endfor -%}
<|CLOSE_SECTION|>

View file

@@ -1,3 +1,5 @@
import os
import argparse
import asyncio
import sys

View file

@@ -167,14 +167,14 @@ class WebsocketHandler(Receiver):
log.info("Configuring clients", clients=clients)
for client in clients:
if client["type"] == "textgenwebui":
if client["type"] in ["textgenwebui", "lmstudio"]:
try:
max_token_length = int(client.get("max_token_length", 2048))
except ValueError:
continue
self.llm_clients[client["name"]] = {
"type": "textgenwebui",
"type": client["type"],
"api_url": client["apiUrl"],
"name": client["name"],
"max_token_length": max_token_length,
@@ -385,7 +385,7 @@ class WebsocketHandler(Receiver):
"status": emission.status,
"data": emission.data,
"max_token_length": client.max_token_length if client else 2048,
"apiUrl": getattr(client, "api_url_base", None) if client else None,
"apiUrl": getattr(client, "api_url", None) if client else None,
}
)
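With this change the handler accepts lmstudio entries using the same shape as textgenwebui. A sketch of a single client entry as sent by the frontend (the values are illustrative):

example_client = {
    "type": "lmstudio",            # now accepted alongside "textgenwebui"
    "name": "local-lmstudio",
    "apiUrl": "http://localhost:1234",
    "max_token_length": 4096,
}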

View file

@@ -43,6 +43,10 @@ __all__ = [
log = structlog.get_logger("talemate")
async_signals.register("game_loop_start")
async_signals.register("game_loop")
async_signals.register("game_loop_actor_iter")
class Character:
"""
@@ -523,8 +527,6 @@ class Player(Actor):
return message
async_signals.register("game_loop_start")
async_signals.register("game_loop")
class Scene(Emitter):
"""
@@ -575,6 +577,7 @@ class Scene(Emitter):
"character_state": signal("character_state"),
"game_loop": async_signals.get("game_loop"),
"game_loop_start": async_signals.get("game_loop_start"),
"game_loop_actor_iter": async_signals.get("game_loop_actor_iter"),
}
self.setup_emitter(scene=self)
@@ -1066,7 +1069,9 @@ class Scene(Emitter):
new_message = await narrator.agent.narrate_character(character)
elif source == "narrate_query":
new_message = await narrator.agent.narrate_query(arg)
elif source == "narrate_dialogue":
character = self.get_character(arg)
new_message = await narrator.agent.narrate_after_dialogue(character)
else:
fn = getattr(narrator.agent, source, None)
if not fn:
@@ -1339,6 +1344,10 @@ class Scene(Emitter):
if await command.execute(message):
break
await self.call_automated_actions()
await self.signals["game_loop_actor_iter"].send(
events.GameLoopActorIterEvent(scene=self, event_type="game_loop_actor_iter", actor=actor)
)
continue
self.saved = False
@@ -1351,6 +1360,10 @@ class Scene(Emitter):
"character", item, character=actor.character
)
await self.signals["game_loop_actor_iter"].send(
events.GameLoopActorIterEvent(scene=self, event_type="game_loop_actor_iter", actor=actor)
)
self.emit_status()
except TalemateInterrupt:

View file

@@ -303,6 +303,9 @@ def strip_partial_sentences(text:str) -> str:
# Sentence ending characters
sentence_endings = ['.', '!', '?', '"', "*"]
if not text:
return text
# Check if the last character is already a sentence ending
if text[-1] in sentence_endings:
return text
@@ -779,7 +782,11 @@ def ensure_dialog_format(line:str, talking_character:str=None) -> str:
lines = []
for _line in line.split("\n"):
_line = ensure_dialog_line_format(_line)
try:
_line = ensure_dialog_line_format(_line)
except Exception as exc:
log.error("ensure_dialog_format", msg="Error ensuring dialog line format", line=_line, exc_info=exc)
pass
lines.append(_line)

View file

@@ -120,7 +120,7 @@ export default {
this.state.currentClient = {
name: 'TextGenWebUI',
type: 'textgenwebui',
apiUrl: 'http://localhost:5000/api',
apiUrl: 'http://localhost:5000',
model_name: '',
max_token_length: 4096,
};

View file

@@ -8,7 +8,7 @@
<v-container>
<v-row>
<v-col cols="6">
<v-select v-model="client.type" :items="['openai', 'textgenwebui']" label="Client Type"></v-select>
<v-select v-model="client.type" :items="['openai', 'textgenwebui', 'lmstudio']" label="Client Type"></v-select>
</v-col>
<v-col cols="6">
<v-text-field v-model="client.name" label="Client Name"></v-text-field>
@@ -17,13 +17,13 @@
</v-row>
<v-row>
<v-col cols="12">
<v-text-field v-model="client.apiUrl" v-if="client.type === 'textgenwebui'" label="API URL"></v-text-field>
<v-text-field v-model="client.apiUrl" v-if="isLocalApiClient(client)" label="API URL"></v-text-field>
<v-select v-model="client.model" v-if="client.type === 'openai'" :items="['gpt-4-1106-preview', 'gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k']" label="Model"></v-select>
</v-col>
</v-row>
<v-row>
<v-col cols="6">
<v-text-field v-model="client.max_token_length" v-if="client.type === 'textgenwebui'" type="number" label="Context Length"></v-text-field>
<v-text-field v-model="client.max_token_length" v-if="isLocalApiClient(client)" type="number" label="Context Length"></v-text-field>
</v-col>
</v-row>
</v-container>
@@ -74,6 +74,9 @@ export default {
save() {
this.$emit('save', this.client); // Emit save event with client object
this.close();
},
isLocalApiClient(client) {
return client.type === 'textgenwebui' || client.type === 'lmstudio';
}
}
}

View file

@@ -0,0 +1,4 @@
{{ system_message }}
### Instruction:
{{ set_response(prompt, "\n\n### Response:\n") }}

View file

@@ -0,0 +1,3 @@
USER:
{{ system_message }}
{{ set_response(prompt, "\nASSISTANT:") }}

View file

@@ -0,0 +1,4 @@
{{ system_message }}
### Instruction:
{{ set_response(prompt, "\n\n### Response:\n") }}

View file

@@ -0,0 +1,2 @@
SYSTEM: {{ system_message }}
USER: {{ set_response(prompt, "\nASSISTANT: ") }}

View file

@@ -0,0 +1,4 @@
<|im_start|>system
{{ system_message }}<|im_end|>
<|im_start|>user
{{ set_response(prompt, "<|im_end|>\n<|im_start|>assistant\n") }}

View file

@@ -0,0 +1,4 @@
<|im_start|>system
{{ system_message }}<|im_end|>
<|im_start|>user
{{ set_response(prompt, "<|im_end|>\n<|im_start|>assistant\n") }}
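For reference, this ChatML template renders to roughly the following prompt string, with set_response appending the assistant header so the model continues from there (the system and user text are illustrative):

rendered = (
    "<|im_start|>system\n"
    "You are a helpful narrator.<|im_end|>\n"
    "<|im_start|>user\n"
    "Describe the room.<|im_end|>\n"
    "<|im_start|>assistant\n"
)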

View file

@@ -6,5 +6,5 @@ call talemate_env\Scripts\activate
REM use poetry to install dependencies
python -m poetry install
echo Virtual environment re-created.
echo Virtual environment updated
pause