mirror of https://github.com/vegu-ai/talemate.git
synced 2025-09-01 18:09:08 +00:00
Prep 0.13.0 (#28)
* requirements.txt file
* windows installs from requirements.txt because of silly permission issues
* relock
* narrator - narrate on dialogue agent actions
* add support for new textgenwebui api
* world state auto regen trigger off of gameloop
* function !rename command
* ensure_dialog_format error handling
* Cat, Nous-Capybara, dolphin-2.2.1
* narrate after dialog rerun fixes, template fixes
* LMStudio client (experimental)
* dolphin yi
* refactor client base
* cruft
* openai client to new base
* more client refactor fixes
* tweak context retrieval prompts
* adjust nous capybara template
* add Tess-Medium
* 0.13.0
* switch back to poetry for windows as well
* error on legacy textgenwebui api
* runpod text gen api url fixed
* fix windows install script
* add follow instruction template
* Psyfighter2
This commit is contained in:
parent f9b23f8705
commit d7e72d27c5

37 changed files with 1315 additions and 875 deletions
@@ -7,10 +7,10 @@ REM activate the virtual environment
 call talemate_env\Scripts\activate
 
 REM install poetry
-python -m pip install poetry "rapidfuzz>=3" -U
+python -m pip install "poetry==1.7.1" "rapidfuzz>=3" -U
 
 REM use poetry to install dependencies
-poetry install
+python -m poetry install
 
 REM copy config.example.yaml to config.yaml only if config.yaml doesn't exist
 IF NOT EXIST config.yaml copy config.example.yaml config.yaml
 
poetry.lock (generated): 383 changed lines
@@ -344,17 +344,17 @@ files = [
 
 [[package]]
 name = "boto3"
-version = "1.28.83"
+version = "1.28.84"
 description = "The AWS SDK for Python"
 optional = false
 python-versions = ">= 3.7"
 files = [
-    {file = "boto3-1.28.83-py3-none-any.whl", hash = "sha256:1d10691911c4b8b9443d3060257ba32b68b6e3cad0eebbb9f69fd1c52a78417f"},
-    {file = "boto3-1.28.83.tar.gz", hash = "sha256:489c4967805b677b7a4030460e4c06c0903d6bc0f6834453611bf87efbd8d8a3"},
+    {file = "boto3-1.28.84-py3-none-any.whl", hash = "sha256:98b01bbea27740720a06f7c7bc0132ae4ce902e640aab090cfb99ad3278449c3"},
+    {file = "boto3-1.28.84.tar.gz", hash = "sha256:adfb915958d7b54d876891ea1599dd83189e35a2442eb41ca52b04ea716180b6"},
 ]
 
 [package.dependencies]
-botocore = ">=1.31.83,<1.32.0"
+botocore = ">=1.31.84,<1.32.0"
 jmespath = ">=0.7.1,<2.0.0"
 s3transfer = ">=0.7.0,<0.8.0"
 
@@ -363,13 +363,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
 
 [[package]]
 name = "botocore"
-version = "1.31.83"
+version = "1.31.84"
 description = "Low-level, data-driven core of boto 3."
 optional = false
 python-versions = ">= 3.7"
 files = [
-    {file = "botocore-1.31.83-py3-none-any.whl", hash = "sha256:c742069e8bfd06d212d712228258ff09fb481b6ec02358e539381ce0fcad065a"},
-    {file = "botocore-1.31.83.tar.gz", hash = "sha256:40914b0fb28f13d709e1f8a4481e278350b77a3987be81acd23715ec8d5fedca"},
+    {file = "botocore-1.31.84-py3-none-any.whl", hash = "sha256:d65bc05793d1a8a8c191a739f742876b4b403c5c713dc76beef262d18f7984a2"},
+    {file = "botocore-1.31.84.tar.gz", hash = "sha256:8913bedb96ad0427660dee083aeaa675466eb662bbf1a47781956b5882aadcc5"},
 ]
 
 [package.dependencies]
@@ -378,7 +378,18 @@ python-dateutil = ">=2.1,<3.0.0"
 urllib3 = {version = ">=1.25.4,<2.1", markers = "python_version >= \"3.10\""}
 
 [package.extras]
-crt = ["awscrt (==0.16.26)"]
+crt = ["awscrt (==0.19.10)"]
 
+[[package]]
+name = "cachetools"
+version = "5.3.2"
+description = "Extensible memoizing collections and decorators"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "cachetools-5.3.2-py3-none-any.whl", hash = "sha256:861f35a13a451f94e301ce2bec7cac63e881232ccce7ed67fab9b5df4d3beaa1"},
+    {file = "cachetools-5.3.2.tar.gz", hash = "sha256:086ee420196f7b2ab9ca2db2520aca326318b68fe5ba8bc4d49cca91add450f2"},
+]
+
 [[package]]
 name = "certifi"
@@ -529,13 +540,13 @@ numpy = "*"
 
 [[package]]
 name = "chromadb"
-version = "0.4.14"
+version = "0.4.17"
 description = "Chroma."
 optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
 files = [
-    {file = "chromadb-0.4.14-py3-none-any.whl", hash = "sha256:c1b59bdfb4b35a40bad0b8927c5ed757adf191ff9db2b9a384dc46a76e1ff10f"},
-    {file = "chromadb-0.4.14.tar.gz", hash = "sha256:0fcef603bcf9c854305020c3f8d368c09b1545d48bd2bceefd51861090f87dad"},
+    {file = "chromadb-0.4.17-py3-none-any.whl", hash = "sha256:8cb88162bc6124441ba5a4b93819463a10e9aaafbe05a3286e876cbdc7a7e11d"},
+    {file = "chromadb-0.4.17.tar.gz", hash = "sha256:120f9d364719b664d5314500f8e6097f0e0b24496bb97a429bc324f8d11f1b52"},
 ]
 
 [package.dependencies]
@@ -544,14 +555,20 @@ chroma-hnswlib = "0.7.3"
 fastapi = ">=0.95.2"
 grpcio = ">=1.58.0"
 importlib-resources = "*"
+kubernetes = ">=28.1.0"
 numpy = {version = ">=1.22.5", markers = "python_version >= \"3.8\""}
 onnxruntime = ">=1.14.1"
+opentelemetry-api = ">=1.2.0"
+opentelemetry-exporter-otlp-proto-grpc = ">=1.2.0"
+opentelemetry-sdk = ">=1.2.0"
 overrides = ">=7.3.1"
 posthog = ">=2.4.0"
 pulsar-client = ">=3.1.0"
 pydantic = ">=1.9"
 pypika = ">=0.48.9"
+PyYAML = ">=6.0.0"
 requests = ">=2.28"
+tenacity = ">=8.2.3"
 tokenizers = ">=0.13.2"
 tqdm = ">=4.65.0"
 typer = ">=0.9.0"
@@ -600,6 +617,23 @@ humanfriendly = ">=9.1"
 [package.extras]
 cron = ["capturer (>=2.4)"]
 
+[[package]]
+name = "deprecated"
+version = "1.2.14"
+description = "Python @deprecated decorator to deprecate old python classes, functions or methods."
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+files = [
+    {file = "Deprecated-1.2.14-py2.py3-none-any.whl", hash = "sha256:6fac8b097794a90302bdbb17b9b815e732d3c4720583ff1b198499d78470466c"},
+    {file = "Deprecated-1.2.14.tar.gz", hash = "sha256:e5323eb936458dccc2582dc6f9c322c852a775a27065ff2b0c4970b9d53d01b3"},
+]
+
+[package.dependencies]
+wrapt = ">=1.10,<2"
+
+[package.extras]
+dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"]
+
 [[package]]
 name = "distro"
 version = "1.8.0"
@@ -822,6 +856,46 @@ smb = ["smbprotocol"]
 ssh = ["paramiko"]
 tqdm = ["tqdm"]
 
+[[package]]
+name = "google-auth"
+version = "2.23.4"
+description = "Google Authentication Library"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "google-auth-2.23.4.tar.gz", hash = "sha256:79905d6b1652187def79d491d6e23d0cbb3a21d3c7ba0dbaa9c8a01906b13ff3"},
+    {file = "google_auth-2.23.4-py2.py3-none-any.whl", hash = "sha256:d4bbc92fe4b8bfd2f3e8d88e5ba7085935da208ee38a134fc280e7ce682a05f2"},
+]
+
+[package.dependencies]
+cachetools = ">=2.0.0,<6.0"
+pyasn1-modules = ">=0.2.1"
+rsa = ">=3.1.4,<5"
+
+[package.extras]
+aiohttp = ["aiohttp (>=3.6.2,<4.0.0.dev0)", "requests (>=2.20.0,<3.0.0.dev0)"]
+enterprise-cert = ["cryptography (==36.0.2)", "pyopenssl (==22.0.0)"]
+pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"]
+reauth = ["pyu2f (>=0.1.5)"]
+requests = ["requests (>=2.20.0,<3.0.0.dev0)"]
+
+[[package]]
+name = "googleapis-common-protos"
+version = "1.61.0"
+description = "Common protobufs used in Google APIs"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "googleapis-common-protos-1.61.0.tar.gz", hash = "sha256:8a64866a97f6304a7179873a465d6eee97b7a24ec6cfd78e0f575e96b821240b"},
+    {file = "googleapis_common_protos-1.61.0-py2.py3-none-any.whl", hash = "sha256:22f1915393bb3245343f6efe87f6fe868532efc12aa26b391b15132e1279f1c0"},
+]
+
+[package.dependencies]
+protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0"
+
+[package.extras]
+grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"]
+
 [[package]]
 name = "grpcio"
 version = "1.59.2"
@@ -1050,6 +1124,25 @@ files = [
     {file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"},
 ]
 
+[[package]]
+name = "importlib-metadata"
+version = "6.8.0"
+description = "Read metadata from Python packages"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "importlib_metadata-6.8.0-py3-none-any.whl", hash = "sha256:3ebb78df84a805d7698245025b975d9d67053cd94c79245ba4b3eb694abe68bb"},
+    {file = "importlib_metadata-6.8.0.tar.gz", hash = "sha256:dbace7892d8c0c4ac1ad096662232f831d4e64f4c4545bd53016a3e9d4654743"},
+]
+
+[package.dependencies]
+zipp = ">=0.5"
+
+[package.extras]
+docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
+perf = ["ipython"]
+testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"]
+
 [[package]]
 name = "importlib-resources"
 version = "6.1.1"
@@ -1187,6 +1280,32 @@ files = [
     {file = "joblib-1.3.2.tar.gz", hash = "sha256:92f865e621e17784e7955080b6d042489e3b8e294949cc44c6eac304f59772b1"},
 ]
 
+[[package]]
+name = "kubernetes"
+version = "28.1.0"
+description = "Kubernetes python client"
+optional = false
+python-versions = ">=3.6"
+files = [
+    {file = "kubernetes-28.1.0-py2.py3-none-any.whl", hash = "sha256:10f56f8160dcb73647f15fafda268e7f60cf7dbc9f8e46d52fcd46d3beb0c18d"},
+    {file = "kubernetes-28.1.0.tar.gz", hash = "sha256:1468069a573430fb1cb5ad22876868f57977930f80a6749405da31cd6086a7e9"},
+]
+
+[package.dependencies]
+certifi = ">=14.05.14"
+google-auth = ">=1.0.1"
+oauthlib = ">=3.2.2"
+python-dateutil = ">=2.5.3"
+pyyaml = ">=5.4.1"
+requests = "*"
+requests-oauthlib = "*"
+six = ">=1.9.0"
+urllib3 = ">=1.24.2,<2.0"
+websocket-client = ">=0.32.0,<0.40.0 || >0.40.0,<0.41.dev0 || >=0.43.dev0"
+
+[package.extras]
+adal = ["adal (>=1.0.2)"]
+
 [[package]]
 name = "lazy-object-proxy"
 version = "1.9.0"
@@ -1259,16 +1378,6 @@ files = [
     {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"},
     {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"},
     {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"},
-    {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"},
-    {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"},
-    {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"},
-    {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"},
-    {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"},
-    {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"},
-    {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"},
-    {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"},
-    {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"},
-    {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"},
     {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"},
     {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"},
     {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"},
@@ -1551,6 +1660,22 @@ files = [
     {file = "numpy-1.25.2.tar.gz", hash = "sha256:fd608e19c8d7c55021dffd43bfe5492fab8cc105cc8986f813f8c3c048b38760"},
 ]
 
+[[package]]
+name = "oauthlib"
+version = "3.2.2"
+description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic"
+optional = false
+python-versions = ">=3.6"
+files = [
+    {file = "oauthlib-3.2.2-py3-none-any.whl", hash = "sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca"},
+    {file = "oauthlib-3.2.2.tar.gz", hash = "sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918"},
+]
+
+[package.extras]
+rsa = ["cryptography (>=3.0.0)"]
+signals = ["blinker (>=1.4.0)"]
+signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"]
+
 [[package]]
 name = "onnxruntime"
 version = "1.16.2"
@@ -1614,6 +1739,101 @@ typing-extensions = ">=4.5,<5"
 [package.extras]
 datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
 
+[[package]]
+name = "opentelemetry-api"
+version = "1.21.0"
+description = "OpenTelemetry Python API"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "opentelemetry_api-1.21.0-py3-none-any.whl", hash = "sha256:4bb86b28627b7e41098f0e93280fe4892a1abed1b79a19aec6f928f39b17dffb"},
+    {file = "opentelemetry_api-1.21.0.tar.gz", hash = "sha256:d6185fd5043e000075d921822fd2d26b953eba8ca21b1e2fa360dd46a7686316"},
+]
+
+[package.dependencies]
+deprecated = ">=1.2.6"
+importlib-metadata = ">=6.0,<7.0"
+
+[[package]]
+name = "opentelemetry-exporter-otlp-proto-common"
+version = "1.21.0"
+description = "OpenTelemetry Protobuf encoding"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "opentelemetry_exporter_otlp_proto_common-1.21.0-py3-none-any.whl", hash = "sha256:97b1022b38270ec65d11fbfa348e0cd49d12006485c2321ea3b1b7037d42b6ec"},
+    {file = "opentelemetry_exporter_otlp_proto_common-1.21.0.tar.gz", hash = "sha256:61db274d8a68d636fb2ec2a0f281922949361cdd8236e25ff5539edf942b3226"},
+]
+
+[package.dependencies]
+backoff = {version = ">=1.10.0,<3.0.0", markers = "python_version >= \"3.7\""}
+opentelemetry-proto = "1.21.0"
+
+[[package]]
+name = "opentelemetry-exporter-otlp-proto-grpc"
+version = "1.21.0"
+description = "OpenTelemetry Collector Protobuf over gRPC Exporter"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "opentelemetry_exporter_otlp_proto_grpc-1.21.0-py3-none-any.whl", hash = "sha256:ab37c63d6cb58d6506f76d71d07018eb1f561d83e642a8f5aa53dddf306087a4"},
+    {file = "opentelemetry_exporter_otlp_proto_grpc-1.21.0.tar.gz", hash = "sha256:a497c5611245a2d17d9aa1e1cbb7ab567843d53231dcc844a62cea9f0924ffa7"},
+]
+
+[package.dependencies]
+backoff = {version = ">=1.10.0,<3.0.0", markers = "python_version >= \"3.7\""}
+deprecated = ">=1.2.6"
+googleapis-common-protos = ">=1.52,<2.0"
+grpcio = ">=1.0.0,<2.0.0"
+opentelemetry-api = ">=1.15,<2.0"
+opentelemetry-exporter-otlp-proto-common = "1.21.0"
+opentelemetry-proto = "1.21.0"
+opentelemetry-sdk = ">=1.21.0,<1.22.0"
+
+[package.extras]
+test = ["pytest-grpc"]
+
+[[package]]
+name = "opentelemetry-proto"
+version = "1.21.0"
+description = "OpenTelemetry Python Proto"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "opentelemetry_proto-1.21.0-py3-none-any.whl", hash = "sha256:32fc4248e83eebd80994e13963e683f25f3b443226336bb12b5b6d53638f50ba"},
+    {file = "opentelemetry_proto-1.21.0.tar.gz", hash = "sha256:7d5172c29ed1b525b5ecf4ebe758c7138a9224441b3cfe683d0a237c33b1941f"},
+]
+
+[package.dependencies]
+protobuf = ">=3.19,<5.0"
+
+[[package]]
+name = "opentelemetry-sdk"
+version = "1.21.0"
+description = "OpenTelemetry Python SDK"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "opentelemetry_sdk-1.21.0-py3-none-any.whl", hash = "sha256:9fe633243a8c655fedace3a0b89ccdfc654c0290ea2d8e839bd5db3131186f73"},
+    {file = "opentelemetry_sdk-1.21.0.tar.gz", hash = "sha256:3ec8cd3020328d6bc5c9991ccaf9ae820ccb6395a5648d9a95d3ec88275b8879"},
+]
+
+[package.dependencies]
+opentelemetry-api = "1.21.0"
+opentelemetry-semantic-conventions = "0.42b0"
+typing-extensions = ">=3.7.4"
+
+[[package]]
+name = "opentelemetry-semantic-conventions"
+version = "0.42b0"
+description = "OpenTelemetry Semantic Conventions"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "opentelemetry_semantic_conventions-0.42b0-py3-none-any.whl", hash = "sha256:5cd719cbfec448af658860796c5d0fcea2fdf0945a2bed2363f42cb1ee39f526"},
+    {file = "opentelemetry_semantic_conventions-0.42b0.tar.gz", hash = "sha256:44ae67a0a3252a05072877857e5cc1242c98d4cf12870159f1a94bec800d38ec"},
+]
+
 [[package]]
 name = "orjson"
 version = "3.9.10"
@@ -1953,6 +2173,31 @@ files = [
     {file = "py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5"},
 ]
 
+[[package]]
+name = "pyasn1"
+version = "0.5.0"
+description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)"
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7"
+files = [
+    {file = "pyasn1-0.5.0-py2.py3-none-any.whl", hash = "sha256:87a2121042a1ac9358cabcaf1d07680ff97ee6404333bacca15f76aa8ad01a57"},
+    {file = "pyasn1-0.5.0.tar.gz", hash = "sha256:97b7290ca68e62a832558ec3976f15cbf911bf5d7c7039d8b861c2a0ece69fde"},
+]
+
+[[package]]
+name = "pyasn1-modules"
+version = "0.3.0"
+description = "A collection of ASN.1-based protocols modules"
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7"
+files = [
+    {file = "pyasn1_modules-0.3.0-py2.py3-none-any.whl", hash = "sha256:d3ccd6ed470d9ffbc716be08bd90efbd44d0734bc9303818f7336070984a162d"},
+    {file = "pyasn1_modules-0.3.0.tar.gz", hash = "sha256:5bd01446b736eb9d31512a30d46c1ac3395d676c6f3cafa4c03eb54b9925631c"},
+]
+
+[package.dependencies]
+pyasn1 = ">=0.4.6,<0.6.0"
+
 [[package]]
 name = "pydantic"
 version = "2.4.2"
@@ -2488,6 +2733,24 @@ urllib3 = ">=1.21.1,<3"
 socks = ["PySocks (>=1.5.6,!=1.5.7)"]
 use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
 
+[[package]]
+name = "requests-oauthlib"
+version = "1.3.1"
+description = "OAuthlib authentication support for Requests."
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+files = [
+    {file = "requests-oauthlib-1.3.1.tar.gz", hash = "sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a"},
+    {file = "requests_oauthlib-1.3.1-py2.py3-none-any.whl", hash = "sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5"},
+]
+
+[package.dependencies]
+oauthlib = ">=3.0.0"
+requests = ">=2.0.0"
+
+[package.extras]
+rsa = ["oauthlib[signedtoken] (>=3.0.0)"]
+
 [[package]]
 name = "rope"
 version = "0.22.0"
@@ -2502,6 +2765,20 @@ files = [
 [package.extras]
 dev = ["build", "pytest", "pytest-timeout"]
 
+[[package]]
+name = "rsa"
+version = "4.9"
+description = "Pure-Python RSA implementation"
+optional = false
+python-versions = ">=3.6,<4"
+files = [
+    {file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"},
+    {file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"},
+]
+
+[package.dependencies]
+pyasn1 = ">=0.1.3"
+
 [[package]]
 name = "runpod"
 version = "1.2.0"
@@ -2909,6 +3186,20 @@ files = [
 [package.dependencies]
 mpmath = ">=0.19"
 
+[[package]]
+name = "tenacity"
+version = "8.2.3"
+description = "Retry code until it succeeds"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "tenacity-8.2.3-py3-none-any.whl", hash = "sha256:ce510e327a630c9e1beaf17d42e6ffacc88185044ad85cf74c0a8887c6a0f88c"},
+    {file = "tenacity-8.2.3.tar.gz", hash = "sha256:5398ef0d78e63f40007c1fb4c0bff96e1911394d2fa8d194f77619c05ff6cc8a"},
+]
+
+[package.extras]
+doc = ["reno", "sphinx", "tornado (>=4.5)"]
+
 [[package]]
 name = "thefuzz"
 version = "0.20.0"
@@ -3415,20 +3706,19 @@ files = [
 
 [[package]]
 name = "urllib3"
-version = "2.0.7"
+version = "1.26.18"
 description = "HTTP library with thread-safe connection pooling, file post, and more."
 optional = false
-python-versions = ">=3.7"
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
 files = [
-    {file = "urllib3-2.0.7-py3-none-any.whl", hash = "sha256:fdb6d215c776278489906c2f8916e6e7d4f5a9b602ccbcfdf7f016fc8da0596e"},
-    {file = "urllib3-2.0.7.tar.gz", hash = "sha256:c97dfde1f7bd43a71c8d2a58e369e9b2bf692d1334ea9f9cae55add7d0dd0f84"},
+    {file = "urllib3-1.26.18-py2.py3-none-any.whl", hash = "sha256:34b97092d7e0a3a8cf7cd10e386f401b3737364026c45e622aa02903dffe0f07"},
+    {file = "urllib3-1.26.18.tar.gz", hash = "sha256:f8ecc1bba5667413457c529ab955bf8c67b45db799d159066261719e328580a0"},
 ]
 
 [package.extras]
-brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"]
-secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"]
-socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
-zstd = ["zstandard (>=0.18.0)"]
+brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"]
+secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"]
+socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"]
 
 [[package]]
 name = "uvicorn"
@@ -3587,6 +3877,22 @@ files = [
 [package.dependencies]
 anyio = ">=3.0.0"
 
+[[package]]
+name = "websocket-client"
+version = "1.6.4"
+description = "WebSocket client for Python with low level API options"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "websocket-client-1.6.4.tar.gz", hash = "sha256:b3324019b3c28572086c4a319f91d1dcd44e6e11cd340232978c684a7650d0df"},
+    {file = "websocket_client-1.6.4-py3-none-any.whl", hash = "sha256:084072e0a7f5f347ef2ac3d8698a5e0b4ffbfcab607628cadabc650fc9a83a24"},
+]
+
+[package.extras]
+docs = ["Sphinx (>=6.0)", "sphinx-rtd-theme (>=1.1.0)"]
+optional = ["python-socks", "wsaccel"]
+test = ["websockets"]
+
 [[package]]
 name = "websockets"
 version = "11.0.3"
@@ -3832,7 +4138,22 @@ files = [
 idna = ">=2.0"
 multidict = ">=4.0"
 
+[[package]]
+name = "zipp"
+version = "3.17.0"
+description = "Backport of pathlib-compatible object wrapper for zip files"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "zipp-3.17.0-py3-none-any.whl", hash = "sha256:0e923e726174922dce09c53c59ad483ff7bbb8e572e00c7f7c46b88556409f31"},
+    {file = "zipp-3.17.0.tar.gz", hash = "sha256:84e64a1c28cf7e91ed2078bb8cc8c259cb19b76942096c8d7b84947690cabaf0"},
+]
+
+[package.extras]
+docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"]
+testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"]
+
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<4.0"
-content-hash = "13dc0c939ece1591caa09211c5a29a839cb63b5a921797ab225fc723b66e0d67"
+content-hash = "8d77eeb6bba3c389345f461840b5257716a397e3ecaebc735a26b06e27361a1a"
@@ -4,7 +4,7 @@ build-backend = "poetry.masonry.api"
 
 [tool.poetry]
 name = "talemate"
-version = "0.12.0"
+version = "0.13.0"
 description = "AI-backed roleplay and narrative tools"
 authors = ["FinalWombat"]
 license = "GNU Affero General Public License v3.0"
@@ -39,9 +39,9 @@ thefuzz = ">=0.20.0"
 tiktoken = ">=0.5.1"
 
 # ChromaDB
-chromadb = ">=0.4,<1"
+chromadb = ">=0.4.17,<1"
 InstructorEmbedding = "^1.0.1"
-torch = ">=2.0.0, !=2.0.1"
+torch = ">=2.1.0"
 sentence-transformers="^2.2.2"
 
 [tool.poetry.dev-dependencies]
@@ -9,7 +9,7 @@ REM activate the virtual environment
 call talemate_env\Scripts\activate
 
 REM install poetry
-python -m pip install poetry "rapidfuzz>=3" -U
+python -m pip install "poetry==1.7.1" "rapidfuzz>=3" -U
 
 REM use poetry to install dependencies
 python -m poetry install
@@ -2,4 +2,4 @@ from .agents import Agent
 from .client import TextGeneratorWebuiClient
 from .tale_mate import *
 
-VERSION = "0.12.0"
+VERSION = "0.13.0"
@@ -328,9 +328,13 @@ class ChromaDBMemoryAgent(MemoryAgent):
                 model_name=instructor_model, device=instructor_device
             )
 
+            log.info("chromadb", status="embedding function ready")
+
             self.db = self.db_client.get_or_create_collection(
                 collection_name, embedding_function=ef
             )
+
+            log.info("chromadb", status="instructor db ready")
         else:
             log.info("chromadb", status="using default embeddings")
             self.db = self.db_client.get_or_create_collection(collection_name)
@@ -461,6 +465,7 @@ class ChromaDBMemoryAgent(MemoryAgent):
 
         #import json
         #print(json.dumps(_results["ids"], indent=2))
+        #print(json.dumps(_results["distances"], indent=2))
 
         results = []
 
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, Callable, List, Optional, Union
 import structlog
+import random
 import talemate.util as util
 from talemate.emit import emit
 import talemate.emit.async_signals
@@ -9,14 +10,23 @@ from talemate.prompts import Prompt
 from talemate.agents.base import set_processing, Agent, AgentAction, AgentActionConfig
 from talemate.agents.world_state import TimePassageEmission
 from talemate.scene_message import NarratorMessage
+from talemate.events import GameLoopActorIterEvent
 import talemate.client as client
 
 from .registry import register
 
+if TYPE_CHECKING:
+    from talemate.tale_mate import Actor, Player, Character
+
 log = structlog.get_logger("talemate.agents.narrator")
 
 @register()
 class NarratorAgent(Agent):
 
+    """
+    Handles narration of the story
+    """
 
     agent_type = "narrator"
     verbose_name = "Narrator"
 
@@ -27,31 +37,78 @@ class NarratorAgent(Agent):
     ):
         self.client = client
 
+        # agent actions
+
         self.actions = {
-            "narrate_time_passage": AgentAction(enabled=False, label="Narrate Time Passage", description="Whenever you indicate passage of time, narrate right after"),
+            "narrate_time_passage": AgentAction(enabled=True, label="Narrate Time Passage", description="Whenever you indicate passage of time, narrate right after"),
+            "narrate_dialogue": AgentAction(
+                enabled=True,
+                label="Narrate Dialogue",
+                description="Narrator will get a chance to narrate after every line of dialogue",
+                config = {
+                    "ai_dialog": AgentActionConfig(
+                        type="number",
+                        label="AI Dialogue",
+                        description="Chance to narrate after every line of dialogue, 1 = always, 0 = never",
+                        value=0.3,
+                        min=0.0,
+                        max=1.0,
+                        step=0.1,
+                    ),
+                    "player_dialog": AgentActionConfig(
+                        type="number",
+                        label="Player Dialogue",
+                        description="Chance to narrate after every line of dialogue, 1 = always, 0 = never",
+                        value=0.3,
+                        min=0.0,
+                        max=1.0,
+                        step=0.1,
+                    ),
+                }
+            ),
         }
 
     def clean_result(self, result):
+
+        """
+        Cleans the result of a narration
+        """
+
         result = result.strip().strip(":").strip()
 
         if "#" in result:
             result = result.split("#")[0]
 
+        character_names = [c.name for c in self.scene.get_characters()]
+
         cleaned = []
         for line in result.split("\n"):
-            if ":" in line.strip():
-                break
+            for character_name in character_names:
+                if line.startswith(f"{character_name}:"):
+                    break
             cleaned.append(line)
 
-        return "\n".join(cleaned)
+        result = "\n".join(cleaned)
+        #result = util.strip_partial_sentences(result)
+        return result
 
     def connect(self, scene):
+
+        """
+        Connect to signals
+        """
+
         super().connect(scene)
         talemate.emit.async_signals.get("agent.world_state.time").connect(self.on_time_passage)
+        talemate.emit.async_signals.get("game_loop_actor_iter").connect(self.on_dialog)
 
     async def on_time_passage(self, event:TimePassageEmission):
+
+        """
+        Handles time passage narration, if enabled
+        """
+
         if not self.actions["narrate_time_passage"].enabled:
             return
 
@@ -59,6 +116,31 @@ class NarratorAgent(Agent):
         narrator_message = NarratorMessage(response, source=f"narrate_time_passage:{event.duration};{event.narrative}")
         emit("narrator", narrator_message)
         self.scene.push_history(narrator_message)
 
+    async def on_dialog(self, event:GameLoopActorIterEvent):
+
+        """
+        Handles dialogue narration, if enabled
+        """
+
+        if not self.actions["narrate_dialogue"].enabled:
+            return
+
+        narrate_on_ai_chance = random.random() < self.actions["narrate_dialogue"].config["ai_dialog"].value
+        narrate_on_player_chance = random.random() < self.actions["narrate_dialogue"].config["player_dialog"].value
+
+        log.debug("narrate on dialog", narrate_on_ai_chance=narrate_on_ai_chance, narrate_on_player_chance=narrate_on_player_chance)
+
+        if event.actor.character.is_player and not narrate_on_player_chance:
+            return
+
+        if not event.actor.character.is_player and not narrate_on_ai_chance:
+            return
+
+        response = await self.narrate_after_dialogue(event.actor.character)
+        narrator_message = NarratorMessage(response, source=f"narrate_dialogue:{event.actor.character.name}")
+        emit("narrator", narrator_message)
+        self.scene.push_history(narrator_message)
+
     @set_processing
     async def narrate_scene(self):
@@ -155,8 +237,9 @@ class NarratorAgent(Agent):
                 "as_narrative": as_narrative,
             }
         )
+        log.info("narrate_query", response=response)
         response = self.clean_result(response.strip())
+        log.info("narrate_query (after clean)", response=response)
         if as_narrative:
             response = f"*{response}*"
 
@@ -265,4 +348,30 @@ class NarratorAgent(Agent):
         response = self.clean_result(response.strip())
         response = f"*{response}*"
 
+        return response
+
+
+    @set_processing
+    async def narrate_after_dialogue(self, character:Character):
+        """
+        Narrate after a line of dialogue
+        """
+
+        response = await Prompt.request(
+            "narrator.narrate-after-dialogue",
+            self.client,
+            "narrate",
+            vars = {
+                "scene": self.scene,
+                "max_tokens": self.client.max_token_length,
+                "character": character,
+                "last_line": str(self.scene.history[-1])
+            }
+        )
+
+        log.info("narrate_after_dialogue", response=response)
+
+        response = self.clean_result(response.strip().strip("*"))
+        response = f"*{response}*"
+
         return response
@@ -8,6 +8,7 @@ import talemate.util as util
 from talemate.prompts import Prompt
 from talemate.scene_message import DirectorMessage, TimePassageMessage
 from talemate.emit import emit
+from talemate.events import GameLoopEvent
 
 from .base import Agent, set_processing, AgentAction, AgentActionConfig, AgentEmission
 from .registry import register
@@ -16,9 +17,6 @@
 import isodate
 import time
 
-if TYPE_CHECKING:
-    from talemate.agents.conversation import ConversationAgentEmission
-
 log = structlog.get_logger("talemate.agents.world_state")
 
|
||||||
|
|
||||||
def connect(self, scene):
|
def connect(self, scene):
|
||||||
super().connect(scene)
|
super().connect(scene)
|
||||||
talemate.emit.async_signals.get("agent.conversation.generated").connect(self.on_conversation_generated)
|
talemate.emit.async_signals.get("game_loop").connect(self.on_game_loop)
|
||||||
|
|
||||||
async def advance_time(self, duration:str, narrative:str=None):
|
async def advance_time(self, duration:str, narrative:str=None):
|
||||||
"""
|
"""
|
||||||
|
@@ -96,7 +94,7 @@ class WorldStateAgent(Agent):
         )
 
 
-    async def on_conversation_generated(self, emission:ConversationAgentEmission):
+    async def on_game_loop(self, emission:GameLoopEvent):
         """
         Called when a conversation is generated
         """
@@ -104,8 +102,7 @@ class WorldStateAgent(Agent):
         if not self.enabled:
             return
 
-        for _ in emission.generation:
-            await self.update_world_state()
+        await self.update_world_state()
 
 
     async def update_world_state(self):
@@ -230,7 +227,7 @@ class WorldStateAgent(Agent):
     ):
 
         response = await Prompt.request(
-            "world_state.analyze-and-follow-instruction",
+            "world_state.analyze-text-and-follow-instruction",
             self.client,
             "analyze_freeform",
             vars = {
@@ -1,4 +1,6 @@
+import os
 from talemate.client.openai import OpenAIClient
 from talemate.client.registry import CLIENT_CLASSES, get_client_class, register
 from talemate.client.textgenwebui import TextGeneratorWebuiClient
-import talemate.client.runpod
+from talemate.client.lmstudio import LMStudioClient
+import talemate.client.runpod
src/talemate/client/base.py (new file): 349 lines
@ -0,0 +1,349 @@
|
||||||
|
"""
|
||||||
|
A unified client base, based on the openai API
|
||||||
|
"""
|
||||||
|
import copy
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
import logging
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
|
from talemate.emit import emit
|
||||||
|
import talemate.instance as instance
|
||||||
|
import talemate.client.presets as presets
|
||||||
|
import talemate.client.system_prompts as system_prompts
|
||||||
|
import talemate.util as util
|
||||||
|
from talemate.client.context import client_context_attribute
|
||||||
|
from talemate.client.model_prompts import model_prompt
|
||||||
|
|
||||||
|
|
||||||
|
# Set up logging level for httpx to WARNING to suppress debug logs.
|
||||||
|
logging.getLogger('httpx').setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
REMOTE_SERVICES = [
|
||||||
|
# TODO: runpod.py should add this to the list
|
||||||
|
".runpod.net"
|
||||||
|
]
|
||||||
|
|
||||||
|
STOPPING_STRINGS = ["<|im_end|>", "</s>"]
|
||||||
|
|
||||||
|
class ClientBase:
|
||||||
|
|
||||||
|
api_url: str
|
||||||
|
model_name: str
|
||||||
|
name:str = None
|
||||||
|
enabled: bool = True
|
||||||
|
current_status: str = None
|
||||||
|
max_token_length: int = 4096
|
||||||
|
randomizable_inference_parameters: list[str] = ["temperature"]
|
||||||
|
processing: bool = False
|
||||||
|
connected: bool = False
|
||||||
|
conversation_retries: int = 5
|
||||||
|
|
||||||
|
client_type = "base"
|
||||||
|
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_url: str,
|
||||||
|
name = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.api_url = api_url
|
||||||
|
self.name = name or self.client_type
|
||||||
|
self.log = structlog.get_logger(f"client.{self.client_type}")
|
||||||
|
self.set_client()
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"{self.client_type}Client[{self.api_url}][{self.model_name or ''}]"
|
||||||
|
|
||||||
|
def set_client(self):
|
||||||
|
self.client = AsyncOpenAI(base_url=self.api_url, api_key="sk-1111")
|
||||||
|
|
||||||
|
def prompt_template(self, sys_msg, prompt):
|
||||||
|
|
||||||
|
"""
|
||||||
|
Applies the appropriate prompt template for the model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not self.model_name:
|
||||||
|
self.log.warning("prompt template not applied", reason="no model loaded")
|
||||||
|
return f"{sys_msg}\n{prompt}"
|
||||||
|
|
||||||
|
return model_prompt(self.model_name, sys_msg, prompt)
|
||||||
|
|
||||||
|
def reconfigure(self, **kwargs):
|
||||||
|
|
||||||
|
"""
|
||||||
|
Reconfigures the client.
|
||||||
|
|
||||||
|
Keyword Arguments:
|
||||||
|
|
||||||
|
- api_url: the API URL to use
|
||||||
|
- max_token_length: the max token length to use
|
||||||
|
- enabled: whether the client is enabled
|
||||||
|
"""
|
||||||
|
|
||||||
|
if "api_url" in kwargs:
|
||||||
|
self.api_url = kwargs["api_url"]
|
||||||
|
|
||||||
|
if "max_token_length" in kwargs:
|
||||||
|
self.max_token_length = kwargs["max_token_length"]
|
||||||
|
|
||||||
|
if "enabled" in kwargs:
|
||||||
|
self.enabled = bool(kwargs["enabled"])
|
||||||
|
|
||||||
|
|
||||||
|
def toggle_disabled_if_remote(self):
|
||||||
|
|
||||||
|
"""
|
||||||
|
If the client is targeting a remote recognized service, this
|
||||||
|
will disable the client.
|
||||||
|
"""
|
||||||
|
|
||||||
|
for service in REMOTE_SERVICES:
|
||||||
|
if service in self.api_url:
|
||||||
|
if self.enabled:
|
||||||
|
self.log.warn("remote service unreachable, disabling client", client=self.name)
|
||||||
|
self.enabled = False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_system_message(self, kind: str) -> str:
|
||||||
|
|
||||||
|
"""
|
||||||
|
Returns the appropriate system message for the given kind of generation
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
|
||||||
|
- kind: the kind of generation
|
||||||
|
"""
|
||||||
|
|
||||||
|
# TODO: make extensible
|
||||||
|
|
||||||
|
if "narrate" in kind:
|
||||||
|
return system_prompts.NARRATOR
|
||||||
|
if "story" in kind:
|
||||||
|
return system_prompts.NARRATOR
|
||||||
|
if "director" in kind:
|
||||||
|
return system_prompts.DIRECTOR
|
||||||
|
if "create" in kind:
|
||||||
|
return system_prompts.CREATOR
|
||||||
|
if "roleplay" in kind:
|
||||||
|
return system_prompts.ROLEPLAY
|
||||||
|
if "conversation" in kind:
|
||||||
|
return system_prompts.ROLEPLAY
|
||||||
|
if "editor" in kind:
|
||||||
|
return system_prompts.EDITOR
|
||||||
|
if "world_state" in kind:
|
||||||
|
return system_prompts.WORLD_STATE
|
||||||
|
if "analyst" in kind:
|
||||||
|
return system_prompts.ANALYST
|
||||||
|
if "analyze" in kind:
|
||||||
|
return system_prompts.ANALYST
|
||||||
|
|
||||||
|
return system_prompts.BASIC
|
||||||
|
|
||||||
|
|
||||||
|
def emit_status(self, processing: bool = None):
|
||||||
|
|
||||||
|
"""
|
||||||
|
Sets and emits the client status.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if processing is not None:
|
||||||
|
self.processing = processing
|
||||||
|
|
||||||
|
if not self.enabled:
|
||||||
|
status = "disabled"
|
||||||
|
model_name = "Disabled"
|
||||||
|
elif not self.connected:
|
||||||
|
status = "error"
|
||||||
|
model_name = "Could not connect"
|
||||||
|
elif self.model_name:
|
||||||
|
status = "busy" if self.processing else "idle"
|
||||||
|
model_name = self.model_name
|
||||||
|
else:
|
||||||
|
model_name = "No model loaded"
|
||||||
|
status = "warning"
|
||||||
|
|
||||||
|
status_change = status != self.current_status
|
||||||
|
self.current_status = status
|
||||||
|
|
||||||
|
emit(
|
||||||
|
"client_status",
|
||||||
|
message=self.client_type,
|
||||||
|
id=self.name,
|
||||||
|
details=model_name,
|
||||||
|
status=status,
|
||||||
|
)
|
||||||
|
|
||||||
|
if status_change:
|
||||||
|
instance.emit_agent_status_by_client(self)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_model_name(self):
|
||||||
|
        models = await self.client.models.list()
        try:
            return models.data[0].id
        except IndexError:
            return None

    async def status(self):
        """
        Send a request to the API to retrieve the loaded AI model name.
        Raises an error if no model name is returned.
        :return: None
        """
        if self.processing:
            return

        if not self.enabled:
            self.connected = False
            self.emit_status()
            return

        try:
            self.model_name = await self.get_model_name()
        except Exception as e:
            self.log.warning("client status error", e=e, client=self.name)
            self.model_name = None
            self.connected = False
            self.toggle_disabled_if_remote()
            self.emit_status()
            return

        self.connected = True

        if not self.model_name or self.model_name == "None":
            self.log.warning("client model not loaded", client=self)
            self.emit_status()
            return

        self.emit_status()

    def generate_prompt_parameters(self, kind:str):
        parameters = {}
        self.tune_prompt_parameters(
            presets.configure(parameters, kind, self.max_token_length),
            kind
        )
        return parameters

    def tune_prompt_parameters(self, parameters:dict, kind:str):
        parameters["stream"] = False
        if client_context_attribute("nuke_repetition") > 0.0 and self.jiggle_enabled_for(kind):
            self.jiggle_randomness(parameters, offset=client_context_attribute("nuke_repetition"))

        fn_tune_kind = getattr(self, f"tune_prompt_parameters_{kind}", None)
        if fn_tune_kind:
            fn_tune_kind(parameters)

    def tune_prompt_parameters_conversation(self, parameters:dict):
        conversation_context = client_context_attribute("conversation")
        parameters["max_tokens"] = conversation_context.get("length", 96)

        dialog_stopping_strings = [
            f"{character}:" for character in conversation_context["other_characters"]
        ]

        if "extra_stopping_strings" in parameters:
            parameters["extra_stopping_strings"] += dialog_stopping_strings
        else:
            parameters["extra_stopping_strings"] = dialog_stopping_strings

    async def generate(self, prompt:str, parameters:dict, kind:str):
        """
        Generates text from the given prompt and parameters.
        """
        self.log.debug("generate", prompt=prompt[:128]+" ...", parameters=parameters)

        try:
            response = await self.client.completions.create(prompt=prompt.strip(), **parameters)
            return response.get("choices", [{}])[0].get("text", "")
        except Exception as e:
            self.log.error("generate error", e=e)
            return ""

    async def send_prompt(
        self, prompt: str, kind: str = "conversation", finalize: Callable = lambda x: x
    ) -> str:
        """
        Send a prompt to the AI and return its response.
        :param prompt: The text prompt to send.
        :return: The AI's response text.
        """
        try:
            self.emit_status(processing=True)
            await self.status()

            prompt_param = self.generate_prompt_parameters(kind)

            finalized_prompt = self.prompt_template(self.get_system_message(kind), prompt).strip()
            prompt_param = finalize(prompt_param)

            token_length = self.count_tokens(finalized_prompt)

            time_start = time.time()
            extra_stopping_strings = prompt_param.pop("extra_stopping_strings", [])

            self.log.debug("send_prompt", token_length=token_length, max_token_length=self.max_token_length, parameters=prompt_param)
            response = await self.generate(finalized_prompt, prompt_param, kind)

            time_end = time.time()

            # stopping strings sometimes get appended to the end of the response anyway;
            # split the response on the first stopping string and keep the first part
            for stopping_string in STOPPING_STRINGS + extra_stopping_strings:
                if stopping_string in response:
                    response = response.split(stopping_string)[0]
                    break

            emit("prompt_sent", data={
                "kind": kind,
                "prompt": finalized_prompt,
                "response": response,
                "prompt_tokens": token_length,
                "response_tokens": self.count_tokens(response),
                "time": time_end - time_start,
            })

            return response
        finally:
            self.emit_status(processing=False)

    def count_tokens(self, content:str):
        return util.count_tokens(content)

    def jiggle_randomness(self, prompt_config:dict, offset:float=0.3) -> dict:
        """
        Adjusts temperature by a random value, using the base value as a center.
        (Subclasses may extend this to also jiggle repetition_penalty.)
        """
        temp = prompt_config["temperature"]
        min_offset = offset * 0.3
        prompt_config["temperature"] = random.uniform(temp + min_offset, temp + offset)

    def jiggle_enabled_for(self, kind:str):
        if kind in ["conversation", "story"]:
            return True

        if kind.startswith("narrate"):
            return True

        return False
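For orientation, here is a minimal sketch of how a concrete client builds on the refactored base class. It is not part of the commit: the constructor keyword arguments (api_url, name) are assumptions inferred from how the subclasses in this diff use self.api_url and self.name, and the sketch assumes ClientBase supplies the remaining helpers (prompt_template, get_system_message, an active client context).

import asyncio

from openai import AsyncOpenAI
from talemate.client.base import ClientBase


class DemoClient(ClientBase):
    # hypothetical subclass, for illustration only
    client_type = "demo"

    def set_client(self):
        # point the OpenAI-compatible client at a local server (URL assumed)
        self.client = AsyncOpenAI(base_url=self.api_url + "/v1", api_key="sk-1111")


async def main():
    client = DemoClient(api_url="http://localhost:5000", name="demo")
    # finalize receives the fully tuned parameter dict and must return it
    response = await client.send_prompt(
        "Describe the tavern.",
        kind="narrate",
        finalize=lambda params: {**params, "temperature": 0.5},
    )
    print(response)


asyncio.run(main())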
src/talemate/client/lmstudio.py (new file, 56 lines)
@@ -0,0 +1,56 @@
from talemate.client.base import ClientBase
from talemate.client.registry import register

from openai import AsyncOpenAI


@register()
class LMStudioClient(ClientBase):

    client_type = "lmstudio"
    conversation_retries = 5

    def set_client(self):
        self.client = AsyncOpenAI(base_url=self.api_url+"/v1", api_key="sk-1111")

    def tune_prompt_parameters(self, parameters:dict, kind:str):
        super().tune_prompt_parameters(parameters, kind)

        keys = list(parameters.keys())

        valid_keys = ["temperature", "top_p"]

        for key in keys:
            if key not in valid_keys:
                del parameters[key]

    async def get_model_name(self):
        model_name = await super().get_model_name()

        # model name comes back as a file path, so we need to extract the model name;
        # the path could be windows or linux, so it needs to handle both backslash and forward slash
        if model_name:
            model_name = model_name.replace("\\", "/").split("/")[-1]

        return model_name

    async def generate(self, prompt:str, parameters:dict, kind:str):
        """
        Generates text from the given prompt and parameters.
        """
        human_message = {'role': 'user', 'content': prompt.strip()}

        self.log.debug("generate", prompt=prompt[:128]+" ...", parameters=parameters)

        try:
            response = await self.client.chat.completions.create(
                model=self.model_name, messages=[human_message], **parameters
            )

            return response.choices[0].message.content
        except Exception as e:
            self.log.error("generate error", e=e)
            return ""
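The path normalization in get_model_name is easiest to see with concrete values; a quick illustration (the example paths are invented):

# illustration of the path stripping above; example paths are invented
for raw in (
    r"C:\models\TheBloke\dolphin-2.2.1\model.gguf",
    "/home/user/models/dolphin-2.2.1.Q5_K_M.gguf",
):
    print(raw.replace("\\", "/").split("/")[-1])
# -> model.gguf
# -> dolphin-2.2.1.Q5_K_M.gguf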
@@ -1,10 +1,9 @@
-import asyncio
 import os
-import time
+import json
-from typing import Callable
 
 from openai import AsyncOpenAI
 
+from talemate.client.base import ClientBase
 from talemate.client.registry import register
 from talemate.emit import emit
 from talemate.config import load_config
@@ -15,10 +14,9 @@ import tiktoken
 __all__ = [
     "OpenAIClient",
 ]
 
 log = structlog.get_logger("talemate")
 
-def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613"):
+def num_tokens_from_messages(messages:list[dict], model:str="gpt-3.5-turbo-0613"):
     """Return the number of tokens used by a list of messages."""
     try:
         encoding = tiktoken.encoding_for_model(model)
@@ -70,7 +68,7 @@ def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613"):
     return num_tokens
 
 @register()
-class OpenAIClient:
+class OpenAIClient(ClientBase):
     """
     OpenAI client for generating text.
     """
@@ -79,13 +77,10 @@ class OpenAIClient:
     conversation_retries = 0
 
     def __init__(self, model="gpt-4-1106-preview", **kwargs):
-        self.name = kwargs.get("name", "openai")
         self.model_name = model
-        self.last_token_length = 0
-        self.max_token_length = 2048
-        self.processing = False
-        self.current_status = "idle"
         self.config = load_config()
+        super().__init__(**kwargs)
 
         # if os.environ.get("OPENAI_API_KEY") is not set, look in the config file
         # and set it
@@ -94,7 +89,7 @@ class OpenAIClient:
         if self.config.get("openai", {}).get("api_key"):
             os.environ["OPENAI_API_KEY"] = self.config["openai"]["api_key"]
 
-        self.set_client(model)
+        self.set_client()
 
     @property
@@ -123,12 +118,14 @@ class OpenAIClient:
             status=status,
         )
 
-    def set_client(self, model:str, max_token_length:int=None):
+    def set_client(self, max_token_length:int=None):
         if not self.openai_api_key:
             log.error("No OpenAI API key set")
             return
 
+        model = self.model_name
+
         self.client = AsyncOpenAI()
         if model == "gpt-3.5-turbo":
             self.max_token_length = min(max_token_length or 4096, 4096)
@@ -144,89 +141,72 @@ class OpenAIClient:
     def reconfigure(self, **kwargs):
         if "model" in kwargs:
             self.model_name = kwargs["model"]
-            self.set_client(self.model_name, kwargs.get("max_token_length"))
+            self.set_client(kwargs.get("max_token_length"))
+
+    def count_tokens(self, content: str):
+        return num_tokens_from_messages([{"content": content}], model=self.model_name)
 
     async def status(self):
         self.emit_status()
 
-    def get_system_message(self, kind: str) -> str:
-        if "narrate" in kind:
-            return system_prompts.NARRATOR
-        if "story" in kind:
-            return system_prompts.NARRATOR
-        if "director" in kind:
-            return system_prompts.DIRECTOR
-        if "create" in kind:
-            return system_prompts.CREATOR
-        if "roleplay" in kind:
-            return system_prompts.ROLEPLAY
-        if "conversation" in kind:
-            return system_prompts.ROLEPLAY
-        if "editor" in kind:
-            return system_prompts.EDITOR
-        if "world_state" in kind:
-            return system_prompts.WORLD_STATE
-        if "analyst" in kind:
-            return system_prompts.ANALYST
-        if "analyze" in kind:
-            return system_prompts.ANALYST
-
-        return system_prompts.BASIC
-
-    async def send_prompt(
-        self, prompt: str, kind: str = "conversation", finalize: Callable = lambda x: x
-    ) -> str:
-
-        right = ""
-        opts = {}
-
-        # only gpt-4-1106-preview supports json_object response coercion
-        supports_json_object = self.model_name in ["gpt-4-1106-preview"]
-
+    def prompt_template(self, system_message:str, prompt:str):
         if "<|BOT|>" in prompt:
             _, right = prompt.split("<|BOT|>", 1)
             if right:
                 prompt = prompt.replace("<|BOT|>", "\nContinue this response: ")
-                expected_response = prompt.split("\nContinue this response: ")[1].strip()
-                if expected_response.startswith("{") and supports_json_object:
-                    opts["response_format"] = {"type": "json_object"}
             else:
                 prompt = prompt.replace("<|BOT|>", "")
 
-        self.emit_status(processing=True)
-        await asyncio.sleep(0.1)
+        return prompt
 
-        sys_message = {'role': 'system', 'content': self.get_system_message(kind)}
-
-        human_message = {'role': 'user', 'content': prompt}
-
-        log.debug("openai send", kind=kind, sys_message=sys_message, opts=opts)
-
-        time_start = time.time()
-
-        response = await self.client.chat.completions.create(model=self.model_name, messages=[sys_message, human_message], **opts)
-
-        time_end = time.time()
-
-        response = response.choices[0].message.content
-
-        if right and response.startswith(right):
-            response = response[len(right):].strip()
-
-        if kind == "conversation":
-            response = response.replace("\n", " ").strip()
-
-        log.debug("openai response", response=response)
-
-        emit("prompt_sent", data={
-            "kind": kind,
-            "prompt": prompt,
-            "response": response,
-            "prompt_tokens": num_tokens_from_messages([sys_message, human_message], model=self.model_name),
-            "response_tokens": num_tokens_from_messages([{"role": "assistant", "content": response}], model=self.model_name),
-            "time": time_end - time_start,
-        })
-
-        self.emit_status(processing=False)
-        return response
+    def tune_prompt_parameters(self, parameters:dict, kind:str):
+        super().tune_prompt_parameters(parameters, kind)
+
+        keys = list(parameters.keys())
+        valid_keys = ["temperature", "top_p"]
+
+        for key in keys:
+            if key not in valid_keys:
+                del parameters[key]
+
+    async def generate(self, prompt:str, parameters:dict, kind:str):
+        """
+        Generates text from the given prompt and parameters.
+        """
+        # only gpt-4-1106-preview supports json_object response coercion
+        supports_json_object = self.model_name in ["gpt-4-1106-preview"]
+
+        right = None
+        try:
+            _, right = prompt.split("\nContinue this response: ")
+            expected_response = right.strip()
+            if expected_response.startswith("{") and supports_json_object:
+                parameters["response_format"] = {"type": "json_object"}
+        except (IndexError, ValueError):
+            # note: the unpacking above raises ValueError when the marker is absent
+            pass
+
+        human_message = {'role': 'user', 'content': prompt.strip()}
+        system_message = {'role': 'system', 'content': self.get_system_message(kind)}
+
+        self.log.debug("generate", prompt=prompt[:128]+" ...", parameters=parameters)
+
+        try:
+            response = await self.client.chat.completions.create(
+                model=self.model_name, messages=[system_message, human_message], **parameters
+            )
+
+            response = response.choices[0].message.content
+
+            if right and response.startswith(right):
+                response = response[len(right):].strip()
+
+            return response
+
+        except Exception as e:
+            self.log.error("generate error", e=e)
+            return ""
src/talemate/client/presets.py (new file, 163 lines)
@@ -0,0 +1,163 @@
__all__ = [
    "configure",
    "set_max_tokens",
    "set_preset",
    "preset_for_kind",
    "max_tokens_for_kind",
    "PRESET_TALEMATE_CONVERSATION",
    "PRESET_TALEMATE_CREATOR",
    "PRESET_LLAMA_PRECISE",
    "PRESET_DIVINE_INTELLECT",
    "PRESET_SIMPLE_1",
]

PRESET_TALEMATE_CONVERSATION = {
    "temperature": 0.65,
    "top_p": 0.47,
    "top_k": 42,
    "repetition_penalty": 1.18,
    "repetition_penalty_range": 2048,
}

PRESET_TALEMATE_CREATOR = {
    "temperature": 0.7,
    "top_p": 0.9,
    "top_k": 20,
    "repetition_penalty": 1.15,
    "repetition_penalty_range": 512,
}

PRESET_LLAMA_PRECISE = {
    'temperature': 0.7,
    'top_p': 0.1,
    'top_k': 40,
    'repetition_penalty': 1.18,
}

PRESET_DIVINE_INTELLECT = {
    'temperature': 1.31,
    'top_p': 0.14,
    'top_k': 49,
    "repetition_penalty_range": 1024,
    'repetition_penalty': 1.17,
}

PRESET_SIMPLE_1 = {
    "temperature": 0.7,
    "top_p": 0.9,
    "top_k": 20,
    "repetition_penalty": 1.15,
}

def configure(config:dict, kind:str, total_budget:int):
    """
    Sets the config based on the kind of text to generate.
    """
    set_preset(config, kind)
    set_max_tokens(config, kind, total_budget)
    return config

def set_max_tokens(config:dict, kind:str, total_budget:int):
    """
    Sets the max_tokens in the config based on the kind of text to generate.
    """
    config["max_tokens"] = max_tokens_for_kind(kind, total_budget)
    return config

def set_preset(config:dict, kind:str):
    """
    Sets the preset in the config based on the kind of text to generate.
    """
    config.update(preset_for_kind(kind))

def preset_for_kind(kind: str):
    if kind == "conversation":
        return PRESET_TALEMATE_CONVERSATION
    elif kind == "conversation_old":
        return PRESET_TALEMATE_CONVERSATION  # Assuming old conversation uses the same preset
    elif kind == "conversation_long":
        return PRESET_TALEMATE_CONVERSATION  # Assuming long conversation uses the same preset
    elif kind == "conversation_select_talking_actor":
        return PRESET_TALEMATE_CONVERSATION  # Assuming select talking actor uses the same preset
    elif kind == "summarize":
        return PRESET_LLAMA_PRECISE
    elif kind == "analyze":
        return PRESET_SIMPLE_1
    elif kind == "analyze_creative":
        return PRESET_DIVINE_INTELLECT
    elif kind == "analyze_long":
        return PRESET_SIMPLE_1  # Assuming long analysis uses the same preset as simple
    elif kind == "analyze_freeform":
        return PRESET_LLAMA_PRECISE
    elif kind == "analyze_freeform_short":
        return PRESET_LLAMA_PRECISE  # Assuming short freeform analysis uses the same preset as precise
    elif kind == "narrate":
        return PRESET_LLAMA_PRECISE
    elif kind == "story":
        return PRESET_DIVINE_INTELLECT
    elif kind == "create":
        return PRESET_TALEMATE_CREATOR
    elif kind == "create_concise":
        return PRESET_TALEMATE_CREATOR  # Assuming concise creation uses the same preset as creator
    elif kind == "create_precise":
        return PRESET_LLAMA_PRECISE
    elif kind == "director":
        return PRESET_SIMPLE_1
    elif kind == "director_short":
        return PRESET_SIMPLE_1  # Assuming short direction uses the same preset as simple
    elif kind == "director_yesno":
        return PRESET_SIMPLE_1  # Assuming yes/no direction uses the same preset as simple
    elif kind == "edit_dialogue":
        return PRESET_DIVINE_INTELLECT
    elif kind == "edit_add_detail":
        return PRESET_DIVINE_INTELLECT  # Assuming adding detail uses the same preset as divine intellect
    elif kind == "edit_fix_exposition":
        return PRESET_DIVINE_INTELLECT  # Assuming fixing exposition uses the same preset as divine intellect
    else:
        return PRESET_SIMPLE_1  # Default preset if none of the kinds match

def max_tokens_for_kind(kind: str, total_budget: int):
    if kind == "conversation":
        return 75  # Example value, adjust as needed
    elif kind == "conversation_old":
        return 75  # Example value, adjust as needed
    elif kind == "conversation_long":
        return 300  # Example value, adjust as needed
    elif kind == "conversation_select_talking_actor":
        return 30  # Example value, adjust as needed
    elif kind == "summarize":
        return 500  # Example value, adjust as needed
    elif kind == "analyze":
        return 500  # Example value, adjust as needed
    elif kind == "analyze_creative":
        return 1024  # Example value, adjust as needed
    elif kind == "analyze_long":
        return 2048  # Example value, adjust as needed
    elif kind == "analyze_freeform":
        return 500  # Example value, adjust as needed
    elif kind == "analyze_freeform_short":
        return 10  # Example value, adjust as needed
    elif kind == "narrate":
        return 500  # Example value, adjust as needed
    elif kind == "story":
        return 300  # Example value, adjust as needed
    elif kind == "create":
        return min(1024, int(total_budget * 0.35))  # Example calculation, adjust as needed
    elif kind == "create_concise":
        return min(400, int(total_budget * 0.25))  # Example calculation, adjust as needed
    elif kind == "create_precise":
        return min(400, int(total_budget * 0.25))  # Example calculation, adjust as needed
    elif kind == "director":
        return min(600, int(total_budget * 0.25))  # Example calculation, adjust as needed
    elif kind == "director_short":
        return 25  # Example value, adjust as needed
    elif kind == "director_yesno":
        return 2  # Example value, adjust as needed
    elif kind == "edit_dialogue":
        return 100  # Example value, adjust as needed
    elif kind == "edit_add_detail":
        return 200  # Example value, adjust as needed
    elif kind == "edit_fix_exposition":
        return 1024  # Example value, adjust as needed
    else:
        return 150  # Default value if none of the kinds match
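For reference, a short sketch of how these helpers compose; the kind and budget values are arbitrary examples:

# sketch: configure() layers a sampling preset and a max_tokens budget into one dict
from talemate.client import presets

parameters = {}
presets.configure(parameters, kind="summarize", total_budget=4096)
print(parameters)
# -> {'temperature': 0.7, 'top_p': 0.1, 'top_k': 40,
#     'repetition_penalty': 1.18, 'max_tokens': 500}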
@@ -67,9 +67,9 @@ def _client_bootstrap(client_type: ClientType, pod):
     id = pod["id"]
 
     if client_type == ClientType.textgen:
-        api_url = f"https://{id}-5000.proxy.runpod.net/api"
+        api_url = f"https://{id}-5000.proxy.runpod.net"
     elif client_type == ClientType.automatic1111:
-        api_url = f"https://{id}-5000.proxy.runpod.net/api"
+        api_url = f"https://{id}-5000.proxy.runpod.net"
 
     return ClientBootstrap(
         client_type=client_type,
@@ -1,735 +1,61 @@
-import asyncio
-import random
-import json
-import copy
-import structlog
-import time
-import httpx
-from abc import ABC, abstractmethod
-from typing import Callable, Union
-import logging
-
-import talemate.util as util
-from talemate.client.registry import register
-import talemate.client.system_prompts as system_prompts
-from talemate.emit import Emission, emit
-from talemate.client.context import client_context_attribute
-from talemate.client.model_prompts import model_prompt
-
-import talemate.instance as instance
-
-log = structlog.get_logger(__name__)
-
-__all__ = [
-    "TaleMateClient",
-    "RestApiTaleMateClient",
-    "TextGeneratorWebuiClient",
-]
-
-# Set up logging level for httpx to WARNING to suppress debug logs.
-logging.getLogger('httpx').setLevel(logging.WARNING)
-
-class DefaultContext(int):
-    pass
-
-
-PRESET_TALEMATE_LEGACY = {
-    "temperature": 0.72,
-    "top_p": 0.73,
-    "top_k": 0,
-    "top_a": 0,
-    "repetition_penalty": 1.18,
-    "repetition_penalty_range": 2048,
-    "encoder_repetition_penalty": 1,
-    #"encoder_repetition_penalty": 1.2,
-    #"no_repeat_ngram_size": 2,
-    "do_sample": True,
-    "length_penalty": 1,
-}
-
-PRESET_TALEMATE_CONVERSATION = {
-    "temperature": 0.65,
-    "top_p": 0.47,
-    "top_k": 42,
-    "typical_p": 1,
-    "top_a": 0,
-    "tfs": 1,
-    "epsilon_cutoff": 0,
-    "eta_cutoff": 0,
-    "repetition_penalty": 1.18,
-    "repetition_penalty_range": 2048,
-    "no_repeat_ngram_size": 0,
-    "penalty_alpha": 0,
-    "num_beams": 1,
-    "length_penalty": 1,
-    "min_length": 0,
-    "encoder_rep_pen": 1,
-    "do_sample": True,
-    "early_stopping": False,
-    "mirostat_mode": 0,
-    "mirostat_tau": 5,
-    "mirostat_eta": 0.1
-}
-
-PRESET_TALEMATE_CREATOR = {
-    "temperature": 0.7,
-    "top_p": 0.9,
-    "repetition_penalty": 1.15,
-    "repetition_penalty_range": 512,
-    "top_k": 20,
-    "do_sample": True,
-    "length_penalty": 1,
-}
-
-PRESET_LLAMA_PRECISE = {
-    'temperature': 0.7,
-    'top_p': 0.1,
-    'repetition_penalty': 1.18,
-    'top_k': 40
-}
-
-PRESET_KOBOLD_GODLIKE = {
-    'temperature': 0.7,
-    'top_p': 0.5,
-    'typical_p': 0.19,
-    'repetition_penalty': 1.1,
-    "repetition_penalty_range": 1024,
-}
-
-PRESET_DIVINE_INTELLECT = {
-    'temperature': 1.31,
-    'top_p': 0.14,
-    "repetition_penalty_range": 1024,
-    'repetition_penalty': 1.17,
-    'top_k': 49,
-    "mirostat_mode": 0,
-    "mirostat_tau": 5,
-    "mirostat_eta": 0.1,
-    "tfs": 1,
-}
-
-PRESET_SIMPLE_1 = {
-    "temperature": 0.7,
-    "top_p": 0.9,
-    "repetition_penalty": 1.15,
-    "top_k": 20,
-}
-
-def jiggle_randomness(prompt_config:dict, offset:float=0.3) -> dict:
-    """
-    adjusts temperature and repetition_penalty
-    by random values using the base value as a center
-    """
-
-    temp = prompt_config["temperature"]
-    rep_pen = prompt_config["repetition_penalty"]
-
-    copied_config = copy.deepcopy(prompt_config)
-
-    min_offset = offset * 0.3
-
-    copied_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
-    copied_config["repetition_penalty"] = random.uniform(rep_pen + min_offset * 0.3, rep_pen + offset * 0.3)
-
-    return copied_config
-
-
-class TaleMateClient:
-    """
-    An abstract TaleMate client that can be implemented for different communication methods with the AI.
-    """
-    def __init__(
-        self,
-        api_url: str,
-        max_token_length: Union[int, DefaultContext] = int.__new__(
-            DefaultContext, 2048
-        ),
-    ):
-        self.api_url = api_url
-        self.name = "generic_client"
-        self.model_name = None
-        self.last_token_length = 0
-        self.max_token_length = max_token_length
-        self.original_max_token_length = max_token_length
-        self.enabled = True
-        self.current_status = None
-
-    @abstractmethod
-    def send_message(self, message: dict) -> str:
-        """
-        Sends a message to the AI. Needs to be implemented by the subclass.
-        :param message: The message to be sent.
-        :return: The AI's response text.
-        """
-        pass
-
-    @abstractmethod
-    def send_prompt(self, prompt: str) -> str:
-        """
-        Sends a prompt to the AI. Needs to be implemented by the subclass.
-        :param prompt: The text prompt to send.
-        :return: The AI's response text.
-        """
-        pass
-
-    def reconfigure(self, **kwargs):
-        if "api_url" in kwargs:
-            self.api_url = kwargs["api_url"]
-
-        if "max_token_length" in kwargs:
-            self.max_token_length = kwargs["max_token_length"]
-
-        if "enabled" in kwargs:
-            self.enabled = bool(kwargs["enabled"])
-
-    def remaining_tokens(self, context: Union[str, list]) -> int:
-        return self.max_token_length - util.count_tokens(context)
-
-    def prompt_template(self, sys_msg, prompt):
-        return model_prompt(self.model_name, sys_msg, prompt)
-
-
-class RESTTaleMateClient(TaleMateClient, ABC):
-    """
-    A RESTful TaleMate client that connects to the REST API endpoint.
-    """
-
-    async def send_message(self, message: dict, url: str) -> str:
-        """
-        Sends a message to the REST API and returns the AI's response.
-        :param message: The message to be sent.
-        :return: The AI's response text.
-        """
-
-        try:
-            async with httpx.AsyncClient() as client:
-                response = await client.post(url, json=message, timeout=None)
-                response_data = response.json()
-                return response_data["results"][0]["text"]
-        except KeyError:
-            return response_data["results"][0]["history"]["visible"][0][-1]
-
-
-@register()
-class TextGeneratorWebuiClient(RESTTaleMateClient):
-    """
-    Client that connects to the text-generatior-webui api
-    """
-
-    client_type = "textgenwebui"
-    conversation_retries = 5
-
-    def __init__(self, api_url: str, max_token_length: int = 2048, **kwargs):
-        api_url = self.cleanup_api_url(api_url)
-        self.api_url_base = api_url
-        api_url = f"{api_url}/v1/chat"
-        super().__init__(api_url, max_token_length=max_token_length)
-        self.model_name = None
-        self.limited_ram = False
-        self.name = kwargs.get("name", "textgenwebui")
-        self.processing = False
-        self.connected = False
-
-    def __str__(self):
-        return f"TextGeneratorWebuiClient[{self.api_url_base}][{self.model_name or ''}]"
-
-    def cleanup_api_url(self, api_url:str):
-        """
-        Strips trailing / and ensures endpoint is /api
-        """
-        if api_url.endswith("/"):
-            api_url = api_url[:-1]
-
-        if not api_url.endswith("/api"):
-            api_url = api_url + "/api"
-
-        return api_url
-
-    def reconfigure(self, **kwargs):
-        super().reconfigure(**kwargs)
-        if "api_url" in kwargs:
-            log.debug("reconfigure", api_url=kwargs["api_url"])
-            api_url = kwargs["api_url"]
-            api_url = self.cleanup_api_url(api_url)
-            self.api_url_base = api_url
-            self.api_url = api_url
-
-    def toggle_disabled_if_remote(self):
-        remote_servies = [
-            ".runpod.net"
-        ]
-
-        for service in remote_servies:
-            if service in self.api_url_base:
-                self.enabled = False
-                return
-
-    def emit_status(self, processing: bool = None):
-        if processing is not None:
-            self.processing = processing
-
-        if not self.enabled:
-            status = "disabled"
-            model_name = "Disabled"
-        elif not self.connected:
-            status = "error"
-            model_name = "Could not connect"
-        elif self.model_name:
-            status = "busy" if self.processing else "idle"
-            model_name = self.model_name
-        else:
-            model_name = "No model loaded"
-            status = "warning"
-
-        status_change = status != self.current_status
-        self.current_status = status
-
-        emit(
-            "client_status",
-            message=self.client_type,
-            id=self.name,
-            details=model_name,
-            status=status,
-        )
-
-        if status_change:
-            instance.emit_agent_status_by_client(self)
-
-    # Add the 'status' method
-    async def status(self):
-        """
-        Send a request to the API to retrieve the loaded AI model name.
-        Raises an error if no model name is returned.
-        :return: None
-        """
-        if not self.enabled:
-            self.connected = False
-            self.emit_status()
-            return
-
-        try:
-            async with httpx.AsyncClient() as client:
-                response = await client.get(f"{self.api_url_base}/v1/model", timeout=2)
-        except (
-            httpx.TimeoutException,
-            httpx.NetworkError,
-        ):
-            self.model_name = None
-            self.connected = False
-            self.toggle_disabled_if_remote()
-            self.emit_status()
-            return
-
-        self.connected = True
-
-        try:
-            response_data = response.json()
-            self.enabled = True
-        except json.decoder.JSONDecodeError as e:
-            self.connected = False
-            self.toggle_disabled_if_remote()
-            if not self.enabled:
-                log.warn("remote service unreachable, disabling client", name=self.name)
-            else:
-                log.error("client response error", name=self.name, e=e)
-            self.emit_status()
-            return
-
-        model_name = response_data.get("result")
-
-        if not model_name or model_name == "None":
-            log.warning("client model not loaded", client=self.name)
-            self.emit_status()
-            return
-
-        model_changed = model_name != self.model_name
-        self.model_name = model_name
-
-        if model_changed:
-            self.auto_context_length()
-
-        log.info(f"{self} [{self.max_token_length} ctx]: ready")
-        self.emit_status()
-
-    def auto_context_length(self):
-        """
-        Automatically sets context length based on LLM
-        """
-        if not isinstance(self.max_token_length, DefaultContext):
-            # context length was specified manually
-            return
-
-        model_name = self.model_name.lower()
-
-        if "longchat" in model_name:
-            self.max_token_length = 16000
-        elif "8k" in model_name:
-            if not self.limited_ram or "13b" in model_name:
-                self.max_token_length = 6000
-            else:
-                self.max_token_length = 4096
-        elif "4k" in model_name:
-            self.max_token_length = 4096
-        else:
-            self.max_token_length = self.original_max_token_length
-
-    @property
-    def instruction_template(self):
-        if "vicuna" in self.model_name.lower():
-            return "Vicuna-v1.1"
-        if "camel" in self.model_name.lower():
-            return "Vicuna-v1.1"
-        return ""
-
-    def prompt_url(self):
-        return self.api_url_base + "/v1/generate"
-
-    def prompt_config_conversation_old(self, prompt: str) -> dict:
-        prompt = self.prompt_template(
-            system_prompts.BASIC,
-            prompt,
-        )
-
-        config = {
-            "prompt": prompt,
-            "max_new_tokens": 75,
-            "truncation_length": self.max_token_length,
-        }
-        config.update(PRESET_TALEMATE_CONVERSATION)
-        return config
-
-    def prompt_config_conversation(self, prompt: str) -> dict:
-        prompt = self.prompt_template(
-            system_prompts.ROLEPLAY,
-            prompt,
-        )
-
-        stopping_strings = ["<|end_of_turn|>"]
-
-        conversation_context = client_context_attribute("conversation")
-
-        stopping_strings += [
-            f"{character}:" for character in conversation_context["other_characters"]
-        ]
-
-        max_new_tokens = conversation_context.get("length", 96)
-        log.debug("prompt_config_conversation", stopping_strings=stopping_strings, conversation_context=conversation_context, max_new_tokens=max_new_tokens)
-
-        config = {
-            "prompt": prompt,
-            "max_new_tokens": max_new_tokens,
-            "truncation_length": self.max_token_length,
-            "stopping_strings": stopping_strings,
-        }
-        config.update(PRESET_TALEMATE_CONVERSATION)
-
-        jiggle_randomness(config)
-
-        return config
-
-    def prompt_config_conversation_long(self, prompt: str) -> dict:
-        config = self.prompt_config_conversation(prompt)
-        config["max_new_tokens"] = 300
-        return config
-
-    def prompt_config_conversation_select_talking_actor(self, prompt: str) -> dict:
-        config = self.prompt_config_conversation(prompt)
-        config["max_new_tokens"] = 30
-        config["stopping_strings"] += [":"]
-        return config
-
-    def prompt_config_summarize(self, prompt: str) -> dict:
-        prompt = self.prompt_template(
-            system_prompts.NARRATOR,
-            prompt,
-        )
-
-        config = {
-            "prompt": prompt,
-            "max_new_tokens": 500,
-            "truncation_length": self.max_token_length,
-        }
-
-        config.update(PRESET_LLAMA_PRECISE)
-        return config
-
-    def prompt_config_analyze(self, prompt: str) -> dict:
-        prompt = self.prompt_template(
-            system_prompts.ANALYST,
-            prompt,
-        )
-
-        config = {
-            "prompt": prompt,
-            "max_new_tokens": 500,
-            "truncation_length": self.max_token_length,
-        }
-
-        config.update(PRESET_SIMPLE_1)
-        return config
-
-    def prompt_config_analyze_creative(self, prompt: str) -> dict:
-        prompt = self.prompt_template(
-            system_prompts.ANALYST,
-            prompt,
-        )
-
-        config = {}
-        config.update(PRESET_DIVINE_INTELLECT)
-        config.update({
-            "prompt": prompt,
-            "max_new_tokens": 1024,
-            "repetition_penalty_range": 1024,
-            "truncation_length": self.max_token_length
-        })
-
-        return config
-
-    def prompt_config_analyze_long(self, prompt: str) -> dict:
-        config = self.prompt_config_analyze(prompt)
-        config["max_new_tokens"] = 2048
-        return config
-
-    def prompt_config_analyze_freeform(self, prompt: str) -> dict:
-        prompt = self.prompt_template(
-            system_prompts.ANALYST_FREEFORM,
-            prompt,
-        )
-
-        config = {
-            "prompt": prompt,
-            "max_new_tokens": 500,
-            "truncation_length": self.max_token_length,
-        }
-
-        config.update(PRESET_LLAMA_PRECISE)
-        return config
-
-    def prompt_config_analyze_freeform_short(self, prompt: str) -> dict:
-        config = self.prompt_config_analyze_freeform(prompt)
-        config["max_new_tokens"] = 10
-        return config
-
-    def prompt_config_narrate(self, prompt: str) -> dict:
-        prompt = self.prompt_template(
-            system_prompts.NARRATOR,
-            prompt,
-        )
-
-        config = {
-            "prompt": prompt,
-            "max_new_tokens": 500,
-            "truncation_length": self.max_token_length,
-        }
-        config.update(PRESET_LLAMA_PRECISE)
-        return config
-
-    def prompt_config_story(self, prompt: str) -> dict:
-        prompt = self.prompt_template(
-            system_prompts.NARRATOR,
-            prompt,
-        )
-
-        config = {
-            "prompt": prompt,
-            "max_new_tokens": 300,
-            "seed": random.randint(0, 1000000000),
-            "truncation_length": self.max_token_length
-        }
-        config.update(PRESET_DIVINE_INTELLECT)
-        config.update({
-            "repetition_penalty": 1.3,
-            "repetition_penalty_range": 2048,
-        })
-        return config
-
-    def prompt_config_create(self, prompt: str) -> dict:
-        prompt = self.prompt_template(
-            system_prompts.CREATOR,
-            prompt,
-        )
-        config = {
-            "prompt": prompt,
-            "max_new_tokens": min(1024, self.max_token_length * 0.35),
-            "truncation_length": self.max_token_length,
-        }
-        config.update(PRESET_TALEMATE_CREATOR)
-        return config
-
-    def prompt_config_create_concise(self, prompt: str) -> dict:
-        prompt = self.prompt_template(
-            system_prompts.CREATOR,
-            prompt,
-        )
-
-        config = {
-            "prompt": prompt,
-            "max_new_tokens": min(400, self.max_token_length * 0.25),
-            "truncation_length": self.max_token_length,
-            "stopping_strings": ["<|DONE|>", "\n\n"]
-        }
-        config.update(PRESET_TALEMATE_CREATOR)
-        return config
-
-    def prompt_config_create_precise(self, prompt: str) -> dict:
-        config = self.prompt_config_create_concise(prompt)
-        config.update(PRESET_LLAMA_PRECISE)
-        return config
-
-    def prompt_config_director(self, prompt: str) -> dict:
-        prompt = self.prompt_template(
-            system_prompts.DIRECTOR,
-            prompt,
-        )
-
-        config = {
-            "prompt": prompt,
-            "max_new_tokens": min(600, self.max_token_length * 0.25),
-            "truncation_length": self.max_token_length,
-        }
-        config.update(PRESET_SIMPLE_1)
-        return config
-
-    def prompt_config_director_short(self, prompt: str) -> dict:
-        config = self.prompt_config_director(prompt)
-        config.update(max_new_tokens=25)
-        return config
-
-    def prompt_config_director_yesno(self, prompt: str) -> dict:
-        config = self.prompt_config_director(prompt)
-        config.update(max_new_tokens=2)
-        return config
-
-    def prompt_config_edit_dialogue(self, prompt:str) -> dict:
-        prompt = self.prompt_template(
-            system_prompts.EDITOR,
-            prompt,
-        )
-
-        conversation_context = client_context_attribute("conversation")
-
-        stopping_strings = [
-            f"{character}:" for character in conversation_context["other_characters"]
-        ]
-
-        config = {
-            "prompt": prompt,
-            "max_new_tokens": 100,
-            "truncation_length": self.max_token_length,
-            "stopping_strings": stopping_strings,
-        }
-
-        config.update(PRESET_DIVINE_INTELLECT)
-
-        return config
-
-    def prompt_config_edit_add_detail(self, prompt:str) -> dict:
-        config = self.prompt_config_edit_dialogue(prompt)
-        config.update(max_new_tokens=200)
-        return config
-
-    def prompt_config_edit_fix_exposition(self, prompt:str) -> dict:
-        config = self.prompt_config_edit_dialogue(prompt)
-        config.update(max_new_tokens=1024)
-        return config
-
-    async def send_prompt(
-        self, prompt: str, kind: str = "conversation", finalize: Callable = lambda x: x
-    ) -> str:
-        """
-        Send a prompt to the AI and return its response.
-        :param prompt: The text prompt to send.
-        :return: The AI's response text.
-        """
-
-        #prompt = prompt.replace("<|BOT|>", "<|BOT|>Certainly! ")
-
-        await self.status()
-        self.emit_status(processing=True)
-
-        await asyncio.sleep(0.01)
-
-        fn_prompt_config = getattr(self, f"prompt_config_{kind}")
-        fn_url = self.prompt_url
-        message = fn_prompt_config(prompt)
-
-        if client_context_attribute("nuke_repetition") > 0.0 and kind in ["conversation", "story"]:
-            log.info("nuke repetition", offset=client_context_attribute("nuke_repetition"), temperature=message["temperature"], repetition_penalty=message["repetition_penalty"])
-            message = jiggle_randomness(message, offset=client_context_attribute("nuke_repetition"))
-            log.info("nuke repetition (applied)", offset=client_context_attribute("nuke_repetition"), temperature=message["temperature"], repetition_penalty=message["repetition_penalty"])
-
-        message = finalize(message)
-
-        token_length = int(len(message["prompt"]) / 3.6)
-
-        self.last_token_length = token_length
-
-        log.debug("send_prompt", token_length=token_length, max_token_length=self.max_token_length)
-
-        message["prompt"] = message["prompt"].strip()
-
-        #print(f"prompt: |{message['prompt']}|")
-
-        # add <|im_end|> to stopping strings
-        if "stopping_strings" in message:
-            message["stopping_strings"] += ["<|im_end|>", "</s>"]
-        else:
-            message["stopping_strings"] = ["<|im_end|>", "</s>"]
-
-        #message["seed"] = -1
-
-        #for k,v in message.items():
-        #    if k == "prompt":
-        #        continue
-        #    print(f"{k}: {v}")
-
-        time_start = time.time()
-
-        response = await self.send_message(message, fn_url())
-
-        time_end = time.time()
-
-        response = response.split("#")[0]
-        self.emit_status(processing=False)
-
-        emit("prompt_sent", data={
-            "kind": kind,
-            "prompt": message["prompt"],
-            "response": response,
-            "prompt_tokens": token_length,
-            "response_tokens": int(len(response) / 3.6),
-            "time": time_end - time_start,
-        })
-
-        return response
-
-
-class OpenAPIClient(RESTTaleMateClient):
-    pass
-
-
-class GPT3Client(OpenAPIClient):
-    pass
-
-
-class GPT4Client(OpenAPIClient):
-    pass
+from talemate.client.base import ClientBase, STOPPING_STRINGS
+from talemate.client.registry import register
+
+from openai import AsyncOpenAI
+
+import httpx
+import copy
+import random
+
+
+@register()
+class TextGeneratorWebuiClient(ClientBase):
+
+    client_type = "textgenwebui"
+
+    def tune_prompt_parameters(self, parameters:dict, kind:str):
+        super().tune_prompt_parameters(parameters, kind)
+        parameters["stopping_strings"] = STOPPING_STRINGS + parameters.get("extra_stopping_strings", [])
+        # is this needed?
+        parameters["max_new_tokens"] = parameters["max_tokens"]
+
+    def set_client(self):
+        self.client = AsyncOpenAI(base_url=self.api_url+"/v1", api_key="sk-1111")
+
+    async def get_model_name(self):
+        async with httpx.AsyncClient() as client:
+            response = await client.get(f"{self.api_url}/v1/internal/model/info", timeout=2)
+            if response.status_code == 404:
+                raise Exception("Could not find model info (wrong api version?)")
+            response_data = response.json()
+            model_name = response_data.get("model_name")
+        return model_name
+
+    async def generate(self, prompt:str, parameters:dict, kind:str):
+        """
+        Generates text from the given prompt and parameters.
+        """
+        headers = {}
+        headers["Content-Type"] = "application/json"
+
+        parameters["prompt"] = prompt.strip()
+
+        async with httpx.AsyncClient() as client:
+            response = await client.post(f"{self.api_url}/v1/completions", json=parameters, timeout=None, headers=headers)
+            response_data = response.json()
+            return response_data["choices"][0]["text"]
+
+    def jiggle_randomness(self, prompt_config:dict, offset:float=0.3) -> dict:
+        """
+        adjusts temperature and repetition_penalty
+        by random values using the base value as a center
+        """
+        temp = prompt_config["temperature"]
+        rep_pen = prompt_config["repetition_penalty"]
+
+        min_offset = offset * 0.3
+
+        prompt_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
+        prompt_config["repetition_penalty"] = random.uniform(rep_pen + min_offset * 0.3, rep_pen + offset * 0.3)
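The substance of this rewrite is that the client now targets text-generation-webui's OpenAI-compatible API rather than the removed legacy /api endpoints. A standalone sketch of the two calls the new client makes; the server address is an assumption:

# standalone sketch of the two endpoints used above (server URL assumed)
import asyncio

import httpx


async def main():
    base = "http://localhost:5000"
    async with httpx.AsyncClient() as client:
        # model discovery, replacing the old /api/v1/model endpoint
        info = await client.get(f"{base}/v1/internal/model/info", timeout=2)
        print(info.json().get("model_name"))

        # completion, replacing the old /api/v1/generate endpoint
        completion = await client.post(
            f"{base}/v1/completions",
            json={"prompt": "Once upon a time", "max_tokens": 32},
            headers={"Content-Type": "application/json"},
            timeout=None,
        )
        print(completion.json()["choices"][0]["text"])


asyncio.run(main())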
src/talemate/client/utils.py (new file, 32 lines)
@@ -0,0 +1,32 @@
import copy
import random

def jiggle_randomness(prompt_config:dict, offset:float=0.3) -> dict:
    """
    adjusts temperature and repetition_penalty
    by random values using the base value as a center
    """

    temp = prompt_config["temperature"]
    rep_pen = prompt_config["repetition_penalty"]

    copied_config = copy.deepcopy(prompt_config)

    min_offset = offset * 0.3

    copied_config["temperature"] = random.uniform(temp + min_offset, temp + offset)
    copied_config["repetition_penalty"] = random.uniform(rep_pen + min_offset * 0.3, rep_pen + offset * 0.3)

    return copied_config


def jiggle_enabled_for(kind:str):
    if kind in ["conversation", "story"]:
        return True

    if kind.startswith("narrate"):
        return True

    return False
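A short sketch of what the module-level jiggle_randomness does to a preset; the fixed seed exists only to make the example reproducible:

import random

from talemate.client.utils import jiggle_randomness

random.seed(0)  # only for reproducibility of the example
config = {"temperature": 0.65, "repetition_penalty": 1.18}
jiggled = jiggle_randomness(config, offset=0.3)
print(config)   # unchanged; the function deep-copies before mutating
print(jiggled)  # temperature drawn from [0.74, 0.95], repetition_penalty from [1.207, 1.27]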
@@ -17,7 +17,26 @@ class CmdRename(TalemateCommand):
     aliases = []
 
     async def run(self):
+        # collect list of characters in the scene
+
+        if self.args:
+            character_name = self.args[0]
+        else:
+            character_names = self.scene.character_names
+            character_name = await wait_for_input("Which character do you want to rename?", data={
+                "input_type": "select",
+                "choices": character_names,
+            })
+
+        character = self.scene.get_character(character_name)
+
+        if not character:
+            self.system_message(f"Character {character_name} not found")
+            return True
+
         name = await wait_for_input("Enter new name: ")
-        self.scene.main_character.character.rename(name)
+        character.rename(name)
         await asyncio.sleep(0)
+
+        return True
@@ -4,7 +4,7 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
-    from talemate.tale_mate import Scene
+    from talemate.tale_mate import Scene, Actor
 
 __all__ = [
     "Event",
@@ -42,4 +42,8 @@ class GameLoopEvent(Event):
 
 @dataclass
 class GameLoopStartEvent(GameLoopEvent):
     pass
+
+
+@dataclass
+class GameLoopActorIterEvent(GameLoopEvent):
+    actor: Actor
@@ -290,6 +290,7 @@ class Prompt:
         env.globals["query_scene"] = self.query_scene
         env.globals["query_memory"] = self.query_memory
         env.globals["query_text"] = self.query_text
+        env.globals["instruct_text"] = self.instruct_text
         env.globals["retrieve_memories"] = self.retrieve_memories
         env.globals["uuidgen"] = lambda: str(uuid.uuid4())
         env.globals["to_int"] = lambda x: int(x)
@@ -394,9 +395,14 @@ class Prompt:
                 f"Answer: " + loop.run_until_complete(memory.query(query, **kwargs)),
             ])
         else:
-            return loop.run_until_complete(memory.multi_query([query], **kwargs))
+            return loop.run_until_complete(memory.multi_query(query.split("\n"), **kwargs))
+
+    def instruct_text(self, instruction:str, text:str):
+        loop = asyncio.get_event_loop()
+        world_state = instance.get_agent("world_state")
+        instruction = instruction.format(**self.vars)
+
+        return loop.run_until_complete(world_state.analyze_and_follow_instruction(text, instruction))
 
     def retrieve_memories(self, lines:list[str], goal:str=None):
@@ -0,0 +1,19 @@
{% block rendered_context -%}
<|SECTION:CONTEXT|>
Content Context: This is a specific scene from {{ scene.context }}
Scenario Premise: {{ scene.description }}
{% for memory in query_memory(last_line, as_question_answer=False, iterate=10) -%}
{{ memory }}

{% endfor %}
{% endblock -%}
<|CLOSE_SECTION|>
{% for scene_context in scene.context_history(budget=max_tokens-200-count_tokens(self.rendered_context())) -%}
{{ scene_context }}
{% endfor %}
<|SECTION:TASK|>
Based on the previous line '{{ last_line }}', create the next line of narration. This line should focus solely on describing sensory details (like sounds, sights, smells, tactile sensations) or external actions that move the story forward. Avoid including any character's internal thoughts, feelings, or dialogue. Your narration should directly respond to '{{ last_line }}', either by elaborating on the immediate scene or by subtly advancing the plot. Generate exactly one sentence of new narration. If the character is trying to determine some state, truth or situation, try to answer as part of the narration.

Be creative and generate something new and interesting.
<|CLOSE_SECTION|>
{{ set_prepared_response('*') }}
@@ -8,13 +8,13 @@
 {% if query.endswith("?") -%}
 Question: {{ query }}
 Extra context: {{ query_memory(query, as_question_answer=False) }}
-Instruction: Analyze Context, History and Dialogue. Be factual and truthful. When evaluating both story and memory, story is more important. You can fill in gaps using imagination as long as it is based on the existing context. Respect the scene progression and answer in the context of the end of the dialogue.
+Instruction: Analyze Context, History and Dialogue. When evaluating both story and memory, story is more important. You can fill in gaps using imagination as long as it is based on the existing context. Respect the scene progression and answer in the context of the end of the dialogue.
 {% else -%}
 Instruction: {{ query }}
 Extra context: {{ query_memory(query, as_question_answer=False) }}
-Answer based on Context, History and Dialogue. Be factual and truthful. When evaluating both story and memory, story is more important. You can fill in gaps using imagination as long as it is based on the existing context.
+Answer based on Context, History and Dialogue. When evaluating both story and memory, story is more important. You can fill in gaps using imagination as long as it is based on the existing context.
 {% endif -%}
 Content Context: This is a specific scene from {{ scene.context }}
-Narration style: point and click adventure game from the 90s
+Your answer should be in the style of short narration that fits the context of the scene.
 <|CLOSE_SECTION|>
 Narrator answers: {% if at_the_end %}{{ bot_token }}At the end of the dialogue, {% endif %}
@@ -8,9 +8,10 @@
 <|SECTION:TASK|>
 Answer the following questions:
 
-{{ query_text("What are 1 to 3 questions to ask the narrator of the story to gather more context from the past for the continuation of this conversation? If a character is asking about a status, location or information about an item or another character, make sure to include question(s) that help gather context for this. Don't explain your reasoning. Don't ask the actors directly.", text, as_question_answer=False) }}
+{{ instruct_text("Ask the narrator three (3) questions to gather more context from the past for the continuation of this conversation. If a character is asking about a state, location or information about an item or another character, make sure to include question(s) that help gather context for this.", text) }}
 
-You answers should be precise, truthful and short.
+Your answers should be precise, truthful and short. Pay close attention to timestamps when retrieving information from the context.
 
 <|CLOSE_SECTION|>
 <|SECTION:RELEVANT CONTEXT|>
+{{ bot_token }}Answers:
@@ -0,0 +1,5 @@
+
+{{ text }}
+
+<|SECTION:TASK|>
+{{ instruction }}

@@ -34,7 +34,7 @@ No dialogue so far
 {% endif -%}
 <|CLOSE_SECTION|>
 <|SECTION:SCENE PROGRESS|>
-{% for scene_context in scene.context_history(budget=300, min_dialogue=5, add_archieved_history=False, max_dialogue=5) -%}
+{% for scene_context in scene.context_history(budget=500, min_dialogue=5, add_archieved_history=False, max_dialogue=5) -%}
 {{ scene_context }}
 {% endfor -%}
 <|CLOSE_SECTION|>

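The budget bump from 300 to 500 gives the scene-progress loop roughly two-thirds more room before older lines drop out. A minimal sketch of how a budget-capped history helper like context_history can work; the function name, whitespace token counting, and signature are illustrative assumptions, not talemate's actual implementation:

# Illustrative budget-trimmed history: keep the newest entries whose combined
# (approximate) token cost still fits the budget. Whitespace tokenization is a
# simplifying assumption; a real implementation counts model tokens.
def trimmed_history(entries: list[str], budget: int = 500, max_dialogue: int = 5) -> list[str]:
    picked: list[str] = []
    used = 0
    for entry in reversed(entries[-max_dialogue:]):  # walk newest-first
        cost = len(entry.split())
        if used + cost > budget:
            break
        picked.append(entry)
        used += cost
    return list(reversed(picked))  # restore chronological order
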
@@ -1,3 +1,5 @@
+import os
+
 import argparse
 import asyncio
 import sys

@@ -167,14 +167,14 @@ class WebsocketHandler(Receiver):
         log.info("Configuring clients", clients=clients)
 
         for client in clients:
-            if client["type"] == "textgenwebui":
+            if client["type"] in ["textgenwebui", "lmstudio"]:
                 try:
                     max_token_length = int(client.get("max_token_length", 2048))
                 except ValueError:
                     continue
 
                 self.llm_clients[client["name"]] = {
-                    "type": "textgenwebui",
+                    "type": client["type"],
                     "api_url": client["apiUrl"],
                     "name": client["name"],
                     "max_token_length": max_token_length,

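The handler now accepts both local API client types instead of hardcoding textgenwebui, and preserves the incoming type rather than overwriting it. The same normalization extracted as a standalone sketch; the constant and function names are hypothetical, not talemate API:

# Sketch of the client-config normalization above, in isolation.
LOCAL_API_TYPES = {"textgenwebui", "lmstudio"}

def normalize_client(client: dict):
    if client["type"] not in LOCAL_API_TYPES:
        return None
    try:
        max_token_length = int(client.get("max_token_length", 2048))
    except ValueError:
        return None  # mirrors the `continue` in the loop above
    return {
        "type": client["type"],      # preserved instead of hardcoding "textgenwebui"
        "api_url": client["apiUrl"],
        "name": client["name"],
        "max_token_length": max_token_length,
    }
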
@@ -385,7 +385,7 @@ class WebsocketHandler(Receiver):
                 "status": emission.status,
                 "data": emission.data,
                 "max_token_length": client.max_token_length if client else 2048,
-                "apiUrl": getattr(client, "api_url_base", None) if client else None,
+                "apiUrl": getattr(client, "api_url", None) if client else None,
             }
         )
 

@@ -43,6 +43,10 @@ __all__ = [
 
 log = structlog.get_logger("talemate")
 
+async_signals.register("game_loop_start")
+async_signals.register("game_loop")
+async_signals.register("game_loop_actor_iter")
+
 
 class Character:
     """

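Moving the register calls to module scope (they previously sat below the Player class and are removed in the next hunk) guarantees the signals exist before any Scene builds its signal table with async_signals.get. A hedged sketch of that register/get/connect/send flow, assuming a blinker-style async signal API as the surrounding code suggests; the import path is an assumption:

# Sketch only: import path and the plain-dict event are assumptions; talemate
# sends typed event objects such as GameLoopActorIterEvent.
import asyncio
import talemate.emit.async_signals as async_signals

async_signals.register("demo_loop")  # must run before any get("demo_loop")

async def on_demo(event):
    print("received:", event)

async def main():
    signal = async_signals.get("demo_loop")
    signal.connect(on_demo)
    await signal.send({"turn": 1})  # connected handlers are awaited

asyncio.run(main())
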
@@ -523,8 +527,6 @@ class Player(Actor):
 
         return message
 
-async_signals.register("game_loop_start")
-async_signals.register("game_loop")
 
 class Scene(Emitter):
     """

@@ -575,6 +577,7 @@ class Scene(Emitter):
             "character_state": signal("character_state"),
             "game_loop": async_signals.get("game_loop"),
             "game_loop_start": async_signals.get("game_loop_start"),
+            "game_loop_actor_iter": async_signals.get("game_loop_actor_iter"),
         }
 
         self.setup_emitter(scene=self)

@@ -1066,7 +1069,9 @@ class Scene(Emitter):
             new_message = await narrator.agent.narrate_character(character)
         elif source == "narrate_query":
             new_message = await narrator.agent.narrate_query(arg)
+        elif source == "narrate_dialogue":
+            character = self.get_character(arg)
+            new_message = await narrator.agent.narrate_after_dialogue(character)
         else:
             fn = getattr(narrator.agent, source, None)
             if not fn:

@@ -1339,6 +1344,10 @@ class Scene(Emitter):
                     if await command.execute(message):
                         break
                 await self.call_automated_actions()
+
+                await self.signals["game_loop_actor_iter"].send(
+                    events.GameLoopActorIterEvent(scene=self, event_type="game_loop_actor_iter", actor=actor)
+                )
                 continue
 
             self.saved = False

@@ -1350,6 +1359,10 @@ class Scene(Emitter):
                 emit(
                     "character", item, character=actor.character
                 )
+
+            await self.signals["game_loop_actor_iter"].send(
+                events.GameLoopActorIterEvent(scene=self, event_type="game_loop_actor_iter", actor=actor)
+            )
 
             self.emit_status()

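With this second send site, the signal fires after both the command-handling path and a normal actor turn, which is what lets an agent (here, the narrator's new narrate-after-dialogue behaviour) react to every completed turn. A sketch of a consumer, using only the event fields visible above (scene, event_type, actor); the handler itself is hypothetical:

# Hypothetical listener for the new signal.
import talemate.emit.async_signals as async_signals  # path is an assumption

async def narrate_after_turn(event):
    if event.actor is not None:
        print(f"actor {event.actor} finished a turn in scene {event.scene}")

async_signals.get("game_loop_actor_iter").connect(narrate_after_turn)
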
@@ -303,6 +303,9 @@ def strip_partial_sentences(text:str) -> str:
     # Sentence ending characters
     sentence_endings = ['.', '!', '?', '"', "*"]
 
+    if not text:
+        return text
+
     # Check if the last character is already a sentence ending
     if text[-1] in sentence_endings:
         return text

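Without the new guard, an empty string would crash on text[-1] with an IndexError before any other check runs. A compact demonstration of the guarded behaviour; this is a simplified re-implementation for illustration, not the full talemate function:

# Why the empty-string guard matters: text[-1] raises IndexError on "".
def strip_partial_sentences_demo(text: str) -> str:
    sentence_endings = ['.', '!', '?', '"', "*"]
    if not text:
        return text
    if text[-1] in sentence_endings:
        return text
    # drop the trailing partial sentence after the last complete ending
    cut = max(text.rfind(ch) for ch in sentence_endings)
    return text[:cut + 1] if cut != -1 else text

assert strip_partial_sentences_demo("") == ""
assert strip_partial_sentences_demo("He waved. She smiled and") == "He waved."
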
@@ -779,7 +782,11 @@ def ensure_dialog_format(line:str, talking_character:str=None) -> str:
     lines = []
 
     for _line in line.split("\n"):
-        _line = ensure_dialog_line_format(_line)
+        try:
+            _line = ensure_dialog_line_format(_line)
+        except Exception as exc:
+            log.error("ensure_dialog_format", msg="Error ensuring dialog line format", line=_line, exc_info=exc)
+            pass
 
         lines.append(_line)

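The try/except keeps one malformed line from aborting formatting of the whole block: on failure the error is logged and the line passes through unformatted. The same pattern in isolation, as a generic sketch rather than the talemate function:

# Generic per-item error isolation: format what you can, log the rest.
def format_lines_safely(lines, formatter, log_error):
    out = []
    for raw in lines:
        formatted = raw
        try:
            formatted = formatter(raw)
        except Exception as exc:
            # keep the original line; one bad line should not sink the batch
            log_error("format_lines_safely failed", line=raw, exc=exc)
        out.append(formatted)
    return out
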
@@ -120,7 +120,7 @@ export default {
       this.state.currentClient = {
         name: 'TextGenWebUI',
         type: 'textgenwebui',
-        apiUrl: 'http://localhost:5000/api',
+        apiUrl: 'http://localhost:5000',
         model_name: '',
         max_token_length: 4096,
       };

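The default drops the /api suffix because the rewritten text-generation-webui API serves an OpenAI-compatible schema from the server root on port 5000, replacing the legacy /api routes (the commit also errors on the legacy API). A hedged request sketch against that endpoint; the /v1/completions path reflects the OpenAI-compatible route documented for recent textgen-webui builds and is an assumption here:

import requests

api_url = "http://localhost:5000"  # the new default shown above
resp = requests.post(
    f"{api_url}/v1/completions",  # OpenAI-compatible route (assumed available)
    json={"prompt": "Say hi.", "max_tokens": 16},
    timeout=60,
)
print(resp.json()["choices"][0]["text"])
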
@@ -8,7 +8,7 @@
     <v-container>
         <v-row>
             <v-col cols="6">
-                <v-select v-model="client.type" :items="['openai', 'textgenwebui']" label="Client Type"></v-select>
+                <v-select v-model="client.type" :items="['openai', 'textgenwebui', 'lmstudio']" label="Client Type"></v-select>
             </v-col>
             <v-col cols="6">
                 <v-text-field v-model="client.name" label="Client Name"></v-text-field>

@@ -17,13 +17,13 @@
         </v-row>
         <v-row>
             <v-col cols="12">
-                <v-text-field v-model="client.apiUrl" v-if="client.type === 'textgenwebui'" label="API URL"></v-text-field>
+                <v-text-field v-model="client.apiUrl" v-if="isLocalApiClient(client)" label="API URL"></v-text-field>
                 <v-select v-model="client.model" v-if="client.type === 'openai'" :items="['gpt-4-1106-preview', 'gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k']" label="Model"></v-select>
             </v-col>
         </v-row>
         <v-row>
             <v-col cols="6">
-                <v-text-field v-model="client.max_token_length" v-if="client.type === 'textgenwebui'" type="number" label="Context Length"></v-text-field>
+                <v-text-field v-model="client.max_token_length" v-if="isLocalApiClient(client)" type="number" label="Context Length"></v-text-field>
             </v-col>
         </v-row>
     </v-container>

@@ -74,6 +74,9 @@ export default {
        save() {
            this.$emit('save', this.client); // Emit save event with client object
            this.close();
+       },
+       isLocalApiClient(client) {
+           return client.type === 'textgenwebui' || client.type === 'lmstudio';
        }
    }
}

templates/llm-prompt/Cat.jinja2 (new file, 4 lines)
@@ -0,0 +1,4 @@
+{{ system_message }}
+
+### Instruction:
+{{ set_response(prompt, "\n\n### Response:\n") }}

templates/llm-prompt/Nous-Capybara.jinja2 (new file, 3 lines)
@@ -0,0 +1,3 @@
+USER:
+{{ system_message }}
+{{ set_response(prompt, "\nASSISTANT:") }}

templates/llm-prompt/Psyfighter2.jinja2 (new file, 4 lines)
@@ -0,0 +1,4 @@
+{{ system_message }}
+
+### Instruction:
+{{ set_response(prompt, "\n\n### Response:\n") }}

templates/llm-prompt/Tess-Medium.jinja2 (new file, 2 lines)
@@ -0,0 +1,2 @@
+SYSTEM: {{ system_message }}
+USER: {{ set_response(prompt, "\nASSISTANT: ") }}

templates/llm-prompt/dolphin-2.2.1-mistral.jinja2 (new file, 4 lines)
@@ -0,0 +1,4 @@
+<|im_start|>system
+{{ system_message }}<|im_end|>
+<|im_start|>user
+{{ set_response(prompt, "<|im_end|>\n<|im_start|>assistant\n") }}

templates/llm-prompt/dolphin-2_2-yi.jinja2 (new file, 4 lines)
@@ -0,0 +1,4 @@
+<|im_start|>system
+{{ system_message }}<|im_end|>
+<|im_start|>user
+{{ set_response(prompt, "<|im_end|>\n<|im_start|>assistant\n") }}

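Each new template wraps system_message and the prompt in one model family's expected chat markup: Alpaca-style Instruction/Response for Cat and Psyfighter2, USER/ASSISTANT turns for Nous-Capybara and Tess-Medium, and ChatML for the two dolphin variants. A sketch of rendering the ChatML variant with plain Jinja2; the set_response stub only approximates talemate's helper, which also primes the expected response:

from jinja2 import Template

def set_response(prompt: str, suffix: str) -> str:
    # stub assumption: append the assistant-turn opener to the rendered prompt
    return prompt + suffix

template_source = (
    "<|im_start|>system\n"
    "{{ system_message }}<|im_end|>\n"
    "<|im_start|>user\n"
    '{{ set_response(prompt, "<|im_end|>\\n<|im_start|>assistant\\n") }}'
)

print(Template(template_source).render(
    system_message="You are the narrator.",
    prompt="Describe the tavern.",
    set_response=set_response,
))
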
@@ -6,5 +6,5 @@ call talemate_env\Scripts\activate
 REM use poetry to install dependencies
 python -m poetry install
 
-echo Virtual environment re-created.
+echo Virtual environment updated
 pause