mirror of
https://github.com/block/goose.git
synced 2026-04-28 03:29:36 +00:00
feat: add observability plugin system (#227)
Co-authored-by: Michael Neale <michael.neale@gmail.com> Co-authored-by: Lifei Zhou <lifei@squareup.com> Co-authored-by: Alice Hau <ahau@squareup.com>
This commit is contained in:
parent
70660258e9
commit
d30b524f45
20 changed files with 300 additions and 157 deletions
|
|
@ -52,17 +52,23 @@ just test
|
|||
> [!NOTE]
|
||||
> This integration is experimental and we don't currently have integration tests for it.
|
||||
|
||||
Developers can use locally hosted Langfuse tracing by applying the custom `observe_wrapper` decorator defined in `packages/exchange/src/langfuse_wrapper.py` to functions for automatic integration with Langfuse.
|
||||
Developers can use locally hosted Langfuse tracing by applying the custom `observe_wrapper` decorator defined in `packages/exchange/src/exchange/observers` to functions for automatic integration with Langfuse, and potentially other observability providers in the future.
|
||||
|
||||
- Add an `observers` array to your profile containing `langfuse`.
|
||||
- Run `just langfuse-server` to start your local Langfuse server. It requires Docker.
|
||||
- Go to http://localhost:3000 and log in with the default email/password output by the shell script (values can also be found in the `.env.langfuse.local` file).
|
||||
- Run Goose with the `--tracing` flag enabled, i.e., `goose session start --tracing`.
|
||||
- View your traces at http://localhost:3000
|
||||
|
||||
To extend tracing to additional functions, import `from exchange.langfuse_wrapper import observe_wrapper` and use the `observe_wrapper()` decorator on functions you wish to enable tracing for. `observe_wrapper` functions the same way as Langfuse's observe decorator.
|
||||
To extend tracing to additional functions, import `from exchange.observers import observe_wrapper` and use the `observe_wrapper()` decorator on functions you wish to enable tracing for. `observe_wrapper` functions the same way as Langfuse's observe decorator.
|
||||
|
||||
Read more about Langfuse's decorator-based tracing [here](https://langfuse.com/docs/sdk/python/decorators).
|
||||
|
||||
### Other observability plugins
|
||||
|
||||
In case locally hosted Langfuse doesn't fit your needs, you can alternatively use other `observer` telemetry plugins to ingest data with the same interface as the Langfuse integration.
|
||||
To do so, extend `packages/exchange/src/exchange/observers/base.py:Observer` and include the new plugin's path as an entrypoint in `exchange`'s `pyproject.toml`.
|
||||
|
||||
## Exchange
|
||||
|
||||
The lower level generation behind goose is powered by the [`exchange`][ai-exchange] package, also in this repo.
|
||||
|
|
|
|||
|
|
@ -193,7 +193,7 @@ Read more about local Langfuse deployments [here](https://langfuse.com/docs/depl
|
|||
|
||||
#### Exchange and Goose integration
|
||||
|
||||
Import `from exchange.langfuse_wrapper import observe_wrapper` and use the `observe_wrapper()` decorator on functions you wish to enable tracing for. `observe_wrapper` functions the same way as Langfuse's observe decorator.
|
||||
Import `from exchange.observers import observe_wrapper`, include `langfuse` in the `observers` list of your profile, and use the `observe_wrapper()` decorator on functions you wish to enable tracing for. `observe_wrapper` functions the same way as Langfuse's observe decorator.
|
||||
|
||||
Read more about Langfuse's decorator-based tracing [here](https://langfuse.com/docs/sdk/python/decorators).
|
||||
|
||||
|
|
|
|||
|
|
@ -42,6 +42,9 @@ passive = "exchange.moderators.passive:PassiveModerator"
|
|||
truncate = "exchange.moderators.truncate:ContextTruncate"
|
||||
summarize = "exchange.moderators.summarizer:ContextSummarizer"
|
||||
|
||||
[project.entry-points."exchange.observer"]
|
||||
langfuse = "exchange.observers.langfuse:LangfuseObserver"
|
||||
|
||||
[project.entry-points."metadata.plugins"]
|
||||
ai-exchange = "exchange:module_name"
|
||||
|
||||
|
|
|
|||
|
|
@ -8,10 +8,10 @@ from tiktoken import get_encoding
|
|||
|
||||
from exchange.checkpoint import Checkpoint, CheckpointData
|
||||
from exchange.content import Text, ToolResult, ToolUse
|
||||
from exchange.langfuse_wrapper import observe_wrapper
|
||||
from exchange.message import Message
|
||||
from exchange.moderators import Moderator
|
||||
from exchange.moderators.truncate import ContextTruncate
|
||||
from exchange.observers import observe_wrapper
|
||||
from exchange.providers import Provider, Usage
|
||||
from exchange.token_usage_collector import _token_usage_collector
|
||||
from exchange.tool import Tool
|
||||
|
|
|
|||
|
|
@ -1,73 +0,0 @@
|
|||
"""
|
||||
Langfuse Integration Module
|
||||
|
||||
This module provides integration with Langfuse, a tool for monitoring and tracing LLM applications.
|
||||
|
||||
Usage:
|
||||
Import this module to enable Langfuse integration.
|
||||
It automatically checks for Langfuse credentials in the .env.langfuse file and for a running Langfuse server.
|
||||
If these are found, it will set up the necessary client and context for tracing.
|
||||
|
||||
Note:
|
||||
Run setup_langfuse.sh which automates the steps for running local Langfuse.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Callable
|
||||
from langfuse.decorators import langfuse_context
|
||||
import sys
|
||||
from io import StringIO
|
||||
from functools import cache, wraps
|
||||
|
||||
## These are the default configurations for local Langfuse server
|
||||
## Please refer to .env.langfuse.local file for local langfuse server setup configurations
|
||||
DEFAULT_LOCAL_LANGFUSE_HOST = "http://localhost:3000"
|
||||
DEFAULT_LOCAL_LANGFUSE_PUBLIC_KEY = "publickey-local"
|
||||
DEFAULT_LOCAL_LANGFUSE_SECRET_KEY = "secretkey-local"
|
||||
|
||||
|
||||
@cache
def auth_check() -> bool:
    """Report whether Langfuse accepts the configured credentials.

    Any of ``LANGFUSE_PUBLIC_KEY``, ``LANGFUSE_SECRET_KEY`` and ``LANGFUSE_HOST``
    that is not already present in the environment is filled in with the
    local-server default before the check runs. A successful result is
    memoized by ``functools.cache``.

    Returns:
        bool: Whatever ``langfuse_context.auth_check()`` reports.
    """
    # Imported locally so the fix does not rely on a new module-level import.
    from contextlib import redirect_stderr

    # Set environment variables if not specified
    os.environ.setdefault("LANGFUSE_PUBLIC_KEY", DEFAULT_LOCAL_LANGFUSE_PUBLIC_KEY)
    os.environ.setdefault("LANGFUSE_SECRET_KEY", DEFAULT_LOCAL_LANGFUSE_SECRET_KEY)
    os.environ.setdefault("LANGFUSE_HOST", DEFAULT_LOCAL_LANGFUSE_HOST)

    # Suppress what Langfuse prints to stderr during the check. The context
    # manager restores the stream that was installed before the call (not
    # unconditionally sys.__stderr__), and does so even if the check raises.
    with redirect_stderr(StringIO()):
        return langfuse_context.auth_check()
|
||||
|
||||
|
||||
def observe_wrapper(*args, **kwargs) -> Callable:  # noqa
    """
    Build a decorator that traces a function with Langfuse when auth succeeds.

    The decision is taken once, when the decorator is applied: if
    ``auth_check()`` is truthy the function is replaced by a wrapper that routes
    every call through ``langfuse_context.observe``; otherwise the function is
    handed back untouched.

    Args:
        *args: Positional arguments forwarded to langfuse_context.observe.
        **kwargs: Keyword arguments forwarded to langfuse_context.observe.

    Returns:
        Callable: A decorator yielding either the tracing wrapper or the original function.
    """

    def _wrapper(fn: Callable) -> Callable:
        # No usable Langfuse credentials: leave the function exactly as it is.
        if not auth_check():
            return fn

        @wraps(fn)
        def wrapped_fn(*fargs, **fkwargs):  # noqa
            observed = langfuse_context.observe(*args, **kwargs)(fn)
            return observed(*fargs, **fkwargs)

        return wrapped_fn

    return _wrapper
|
||||
20
packages/exchange/src/exchange/observers/__init__.py
Normal file
20
packages/exchange/src/exchange/observers/__init__.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
from functools import wraps
|
||||
from typing import Callable
|
||||
|
||||
from exchange.observers.base import ObserverManager
|
||||
|
||||
|
||||
def observe_wrapper(*args, **kwargs) -> Callable:  # noqa: ANN002, ANN003
    """Decorate a function so each call is routed through every registered observer.

    The observer list is read from ``ObserverManager.get_instance()`` at call
    time rather than at decoration time, so observers registered after the
    decorated function was defined are still applied.

    Args:
        *args: Positional arguments handed to each observer's ``observe_wrapper``.
        **kwargs: Keyword arguments handed to each observer's ``observe_wrapper``.

    Returns:
        Callable: A decorator producing the dynamically wrapped function.
    """

    def wrapper(func: Callable) -> Callable:
        @wraps(func)
        def dynamic_wrapped(*func_args, **func_kwargs) -> Callable:  # noqa: ANN002, ANN003
            manager = ObserverManager.get_instance()

            # Layer the observers one on top of another, in registration order.
            target = func
            for plugin in manager._observers:
                decorate = plugin.observe_wrapper(*args, **kwargs)
                target = decorate(target)

            return target(*func_args, **func_kwargs)

        return dynamic_wrapped

    return wrapper
|
||||
43
packages/exchange/src/exchange/observers/base.py
Normal file
43
packages/exchange/src/exchange/observers/base.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Type
|
||||
|
||||
|
||||
class Observer(ABC):
    """Abstract interface every observability plugin has to implement."""

    @abstractmethod
    def initialize(self) -> None:
        """Prepare the plugin for use."""

    @abstractmethod
    def observe_wrapper(self, *args, **kwargs) -> Callable:  # noqa: ANN002, ANN003
        """Return a decorator that instruments the function it is applied to.

        Declared with an explicit ``self`` so the abstract signature matches
        the concrete implementations, which are all instance methods.

        Args:
            *args: Plugin-specific positional options.
            **kwargs: Plugin-specific keyword options.

        Returns:
            Callable: A decorator taking a function and returning the
            function to call in its place.
        """

    @abstractmethod
    def finalize(self) -> None:
        """Release or flush whatever the plugin is holding."""
|
||||
|
||||
|
||||
class ObserverManager:
    """Holds the active observers and drives their initialize/finalize calls.

    A single shared instance is stored on the class and obtained through
    ``get_instance()``.
    """

    _instance = None
    _observers: list[Observer] = []

    @classmethod
    def get_instance(cls: Type["ObserverManager"]) -> "ObserverManager":
        """Return the shared manager, creating it on first use."""
        if cls._instance is not None:
            return cls._instance
        cls._instance = cls()
        return cls._instance

    def initialize(self, tracing: bool, observers: list[Observer]) -> None:
        """Register ``observers`` and initialize them according to ``tracing``.

        Args:
            tracing: Whether tracing was requested.
            observers: The observer instances to manage from now on.
        """
        from exchange.observers.langfuse import LangfuseObserver

        self._observers = observers
        for plugin in observers:
            if tracing:
                plugin.initialize()
            elif isinstance(plugin, LangfuseObserver):
                # LangfuseObserver has special behavior when tracing is _dis_abled.
                # Consider refactoring to make this less special-casey if that's common.
                plugin.initialize_with_disabled_tracing()

    def finalize(self) -> None:
        """Call ``finalize()`` on every registered observer."""
        for plugin in self._observers:
            plugin.finalize()
|
||||
100
packages/exchange/src/exchange/observers/langfuse.py
Normal file
100
packages/exchange/src/exchange/observers/langfuse.py
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
"""
|
||||
Langfuse Observer
|
||||
|
||||
This observer provides integration with Langfuse, a tool for monitoring and tracing LLM applications.
|
||||
|
||||
Usage:
|
||||
Include "langfuse" in your profile's list of observers to enable Langfuse integration.
|
||||
It automatically checks for Langfuse credentials in the .env.langfuse file and for a running Langfuse server.
|
||||
If these are found, it will set up the necessary client and context for tracing.
|
||||
|
||||
Note:
|
||||
Run setup_langfuse.sh which automates the steps for running local Langfuse.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from functools import cache, wraps
|
||||
from io import StringIO
|
||||
from typing import Callable
|
||||
|
||||
from langfuse.decorators import langfuse_context
|
||||
|
||||
from exchange.observers.base import Observer
|
||||
|
||||
## These are the default configurations for local Langfuse server
|
||||
## Please refer to .env.langfuse.local file for local langfuse server setup configurations
|
||||
DEFAULT_LOCAL_LANGFUSE_HOST = "http://localhost:3000"
|
||||
DEFAULT_LOCAL_LANGFUSE_PUBLIC_KEY = "publickey-local"
|
||||
DEFAULT_LOCAL_LANGFUSE_SECRET_KEY = "secretkey-local"
|
||||
|
||||
|
||||
@cache
def auth_check() -> bool:
    """Report whether Langfuse accepts the configured credentials.

    Any of ``LANGFUSE_PUBLIC_KEY``, ``LANGFUSE_SECRET_KEY`` and ``LANGFUSE_HOST``
    that is not already present in the environment is filled in with the
    local-server default before the check runs. A successful result is
    memoized by ``functools.cache``.

    Returns:
        bool: Whatever ``langfuse_context.auth_check()`` reports.
    """
    # Imported locally so the fix does not rely on a new module-level import.
    from contextlib import redirect_stderr

    # Set environment variables if not specified
    os.environ.setdefault("LANGFUSE_PUBLIC_KEY", DEFAULT_LOCAL_LANGFUSE_PUBLIC_KEY)
    os.environ.setdefault("LANGFUSE_SECRET_KEY", DEFAULT_LOCAL_LANGFUSE_SECRET_KEY)
    os.environ.setdefault("LANGFUSE_HOST", DEFAULT_LOCAL_LANGFUSE_HOST)

    # Suppress what Langfuse prints to stderr during the check. The context
    # manager restores the stream that was installed before the call (not
    # unconditionally sys.__stderr__), and does so even if the check raises.
    with redirect_stderr(StringIO()):
        return langfuse_context.auth_check()
|
||||
|
||||
|
||||
class LangfuseObserver(Observer):
    """Observer plugin that sends traces to Langfuse through ``langfuse_context``."""

    # Class-level default so observe_wrapper() can be used on an instance that
    # has not gone through initialize()/initialize_with_disabled_tracing() yet.
    tracing: bool = False

    def initialize(self) -> None:
        """Turn tracing on after verifying that Langfuse is reachable.

        Raises:
            RuntimeError: If ``auth_check()`` does not succeed.
        """
        if not auth_check():
            raise RuntimeError(
                "You passed --tracing, but a Langfuse object was not found in the current context. "
                "Please initialize the local Langfuse server and restart Goose."
            )

        print("Local Langfuse initialized. View your traces at http://localhost:3000")
        langfuse_context.configure(enabled=True)
        self.tracing = True

    def initialize_with_disabled_tracing(self) -> None:
        """Turn tracing off and raise the ``langfuse`` logger level to ERROR."""
        logging.getLogger("langfuse").setLevel(logging.ERROR)
        langfuse_context.configure(enabled=False)
        self.tracing = False

    def session_id_wrapper(self, func: Callable, session_id: str) -> Callable:
        """Wrap ``func`` so the current trace is tagged with ``session_id`` before it runs.

        Args:
            func: The function to call once the trace has been updated.
            session_id: Value passed to ``langfuse_context.update_current_trace``.

        Returns:
            Callable: The wrapper, carrying ``func``'s metadata.
        """

        @wraps(func)  # This will preserve the metadata of 'func'
        def wrapper(*args, **kwargs) -> object:  # noqa: ANN002, ANN003
            langfuse_context.update_current_trace(session_id=session_id)
            return func(*args, **kwargs)

        return wrapper

    def observe_wrapper(self, *args, **kwargs) -> Callable:  # noqa: ANN002, ANN003
        """Return a decorator that traces a function with ``langfuse_context.observe``.

        Args:
            *args: Positional arguments forwarded to ``langfuse_context.observe``.
            **kwargs: Keyword arguments forwarded to ``langfuse_context.observe``.
                An optional ``session_id`` entry is consumed here instead: it
                must be a callable that receives the first positional argument
                of the decorated call and returns the session id to group the
                trace under. ``session_id=None`` is treated as absent.

        Returns:
            Callable: A decorator. It hands back the function unchanged when
            tracing is off or ``auth_check()`` fails at decoration time.
        """
        # Split session_id off exactly once. kwargs is shared by every call of
        # the decorated function, so popping it inside wrapped_fn would silently
        # drop the session grouping from the second call onwards.
        session_id_function = kwargs.pop("session_id", None)

        def _wrapper(fn: Callable) -> Callable:
            if not (self.tracing and auth_check()):
                return fn

            @wraps(fn)
            def wrapped_fn(*fargs, **fkwargs) -> object:  # noqa: ANN002, ANN003
                target = fn
                if session_id_function is not None:
                    # group all traces under the same session
                    session_id_value = session_id_function(fargs[0])
                    target = self.session_id_wrapper(fn, session_id_value)
                return langfuse_context.observe(*args, **kwargs)(target)(*fargs, **fkwargs)

            return wrapped_fn

        return _wrapper

    def finalize(self) -> None:
        """Flush pending data by calling ``langfuse_context.flush()``."""
        langfuse_context.flush()
|
||||
|
|
@ -7,7 +7,7 @@ from exchange.content import Text, ToolResult, ToolUse
|
|||
from exchange.providers.base import Provider, Usage
|
||||
from tenacity import retry, wait_fixed, stop_after_attempt
|
||||
from exchange.providers.utils import retry_if_status, raise_for_status
|
||||
from exchange.langfuse_wrapper import observe_wrapper
|
||||
from exchange.observers import observe_wrapper
|
||||
|
||||
ANTHROPIC_HOST = "https://api.anthropic.com/v1/messages"
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from exchange.providers import Provider, Usage
|
|||
from tenacity import retry, wait_fixed, stop_after_attempt
|
||||
from exchange.providers.utils import raise_for_status, retry_if_status
|
||||
from exchange.tool import Tool
|
||||
from exchange.langfuse_wrapper import observe_wrapper
|
||||
from exchange.observers import observe_wrapper
|
||||
|
||||
SERVICE = "bedrock-runtime"
|
||||
UTC = timezone.utc
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from exchange.providers.utils import (
|
|||
tools_to_openai_spec,
|
||||
)
|
||||
from exchange.tool import Tool
|
||||
from exchange.langfuse_wrapper import observe_wrapper
|
||||
from exchange.observers import observe_wrapper
|
||||
|
||||
retry_procedure = retry(
|
||||
wait=wait_fixed(2),
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from exchange.content import Text, ToolResult, ToolUse
|
|||
from exchange.providers.base import Provider, Usage
|
||||
from tenacity import retry, wait_fixed, stop_after_attempt
|
||||
from exchange.providers.utils import raise_for_status, retry_if_status, encode_image
|
||||
from exchange.langfuse_wrapper import observe_wrapper
|
||||
from exchange.observers import observe_wrapper
|
||||
|
||||
|
||||
GOOGLE_HOST = "https://generativelanguage.googleapis.com/v1beta"
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import os
|
||||
|
||||
from exchange.langfuse_wrapper import observe_wrapper
|
||||
from exchange.observers import observe_wrapper
|
||||
import httpx
|
||||
|
||||
from exchange.message import Message
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from exchange.providers.utils import (
|
|||
from exchange.tool import Tool
|
||||
from tenacity import retry, wait_fixed, stop_after_attempt
|
||||
from exchange.providers.utils import retry_if_status
|
||||
from exchange.langfuse_wrapper import observe_wrapper
|
||||
from exchange.observers import observe_wrapper
|
||||
|
||||
OPENAI_HOST = "https://api.openai.com/"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,48 +0,0 @@
|
|||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from exchange.langfuse_wrapper import observe_wrapper
|
||||
|
||||
|
||||
@pytest.fixture
def mock_langfuse_context():
    """Yield a mock standing in for ``exchange.langfuse_wrapper.langfuse_context``."""
    context_patch = patch("exchange.langfuse_wrapper.langfuse_context")
    with context_patch as patched_context:
        yield patched_context
|
||||
|
||||
|
||||
@patch("exchange.langfuse_wrapper.auth_check")
def test_function_is_wrapped(mock_auth_check, mock_langfuse_context):
    """With auth succeeding, the decorator wraps the function and forwards its arguments."""
    mock_auth_check.return_value = True
    observe_spy = MagicMock(side_effect=lambda *args, **kwargs: lambda fn: fn)
    mock_langfuse_context.observe = observe_spy

    def original_function(x: int, y: int) -> int:
        return x + y

    # Before decoration the function carries no __wrapped__ marker.
    assert not hasattr(original_function, "__wrapped__")

    # Decorator arguments must reach observe(), e.g. @observe(capture_input=False, capture_output=False)
    decorator = observe_wrapper("arg1", kwarg1="kwarg1")
    decorated_function = decorator(original_function)
    assert hasattr(decorated_function, "__wrapped__")
    assert decorated_function.__wrapped__ is original_function, "Function is not properly wrapped"

    assert decorated_function(2, 3) == 5
    observe_spy.assert_called_once_with("arg1", kwarg1="kwarg1")
|
||||
|
||||
|
||||
@patch("exchange.langfuse_wrapper.auth_check")
def test_function_is_not_wrapped(mock_auth_check, mock_langfuse_context):
    """With auth failing, the decorator returns the function untouched and never calls observe()."""
    mock_auth_check.return_value = False
    observe_spy = MagicMock(return_value=lambda f: f)
    mock_langfuse_context.observe = observe_spy

    def hello() -> str:
        return "Hello"

    # Apply the decorator explicitly; this is what the @-syntax does.
    hello = observe_wrapper("arg1", kwarg1="kwarg1")(hello)

    assert not hasattr(hello, "__wrapped__")
    assert hello() == "Hello"

    observe_spy.assert_not_called()
|
||||
61
packages/exchange/tests/test_observer.py
Normal file
61
packages/exchange/tests/test_observer.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
from exchange.observers import ObserverManager, observe_wrapper
|
||||
from exchange.observers.base import Observer
|
||||
|
||||
|
||||
class MockObserver(Observer):
    """Test double that records the arguments its ``observe_wrapper`` decorator was built with."""

    def __init__(self):
        self.initialized = False
        self.finalized = False
        # Filled in only once the decorator returned by observe_wrapper() is applied.
        self.args = None
        self.kwargs = None

    def initialize(self):
        return None

    def observe_wrapper(self, *args, **kwargs):
        def wrapper(func):
            self.args, self.kwargs = args, kwargs
            return func

        return wrapper

    def finalize(self):
        return None
|
||||
|
||||
|
||||
def test_wrapper_is_invoked():
    """The observer's decorator runs when the wrapped function is called, not when it is defined."""
    spy = MockObserver()
    ObserverManager.get_instance().initialize(True, [spy])

    @observe_wrapper("arg0", arg1="arg2")
    def wrapped(x: int, y: int) -> int:
        return x + y

    # Nothing has been recorded yet: defining the function does not touch the observer.
    assert spy.args is None
    assert spy.kwargs is None

    assert wrapped(2, 3) == 5

    # Calling `wrapped` is what triggered the observer's decorator.
    assert spy.args == ("arg0",)
    assert spy.kwargs == {"arg1": "arg2"}
|
||||
|
||||
|
||||
def test_multiple_wrappers():
    """Every registered observer receives the decorator arguments."""
    spies = [MockObserver(), MockObserver()]
    ObserverManager.get_instance().initialize(True, spies)

    @observe_wrapper("arg0")
    def wrapped(x: int, y: int) -> int:
        return x + y

    wrapped(2, 3)

    for spy in spies:
        assert spy.args == ("arg0",)
|
||||
|
|
@ -1,12 +1,10 @@
|
|||
from datetime import datetime
|
||||
import logging
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from exchange import Message, Text, ToolResult, ToolUse
|
||||
from exchange.langfuse_wrapper import auth_check, observe_wrapper
|
||||
from langfuse.decorators import langfuse_context
|
||||
from exchange.observers import ObserverManager, observe_wrapper
|
||||
from rich import print
|
||||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
|
|
@ -79,23 +77,17 @@ class Session:
|
|||
self.notifier = SessionNotifier(self.status_indicator)
|
||||
self.has_plan = plan is not None
|
||||
self.tracing = tracing
|
||||
if not tracing:
|
||||
logging.getLogger("langfuse").setLevel(logging.ERROR)
|
||||
else:
|
||||
langfuse_auth = auth_check()
|
||||
if langfuse_auth:
|
||||
print("Local Langfuse initialized. View your traces at http://localhost:3000")
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"You passed --tracing, but a Langfuse object was not found in the current context. "
|
||||
"Please initialize the local Langfuse server and restart Goose."
|
||||
)
|
||||
if self.tracing:
|
||||
langfuse_context.configure(enabled=tracing)
|
||||
|
||||
self.exchange = create_exchange(profile=load_profile(profile), notifier=self.notifier)
|
||||
setup_logging(log_file_directory=LOG_PATH, log_level=log_level)
|
||||
|
||||
all_observers = load_plugins(group="exchange.observer")
|
||||
profile_observer_names = load_profile(profile).observers
|
||||
observers_to_init = [all_observers[o.name]() for o in profile_observer_names if o.name in all_observers]
|
||||
|
||||
self.observer_manager = ObserverManager.get_instance()
|
||||
self.observer_manager.initialize(tracing=tracing, observers=observers_to_init)
|
||||
|
||||
self.exchange.messages.extend(self._get_initial_messages())
|
||||
|
||||
if len(self.exchange.messages) == 0 and plan:
|
||||
|
|
@ -103,6 +95,10 @@ class Session:
|
|||
|
||||
self.prompt_session = GoosePromptSession()
|
||||
|
||||
def __del__(self) -> None:
    """Finalize the observer manager, if this instance ever got one."""
    # The attribute may be missing, hence the guard before using it.
    if not hasattr(self, "observer_manager"):
        return
    self.observer_manager.finalize()
|
||||
|
||||
def _get_initial_messages(self) -> list[Message]:
|
||||
messages = self.load_session()
|
||||
|
||||
|
|
@ -211,12 +207,9 @@ class Session:
|
|||
time_end = datetime.now()
|
||||
self._log_cost(start_time=time_start, end_time=time_end)
|
||||
|
||||
@observe_wrapper()
|
||||
@observe_wrapper(session_id=lambda instance: instance.name)
|
||||
def reply(self) -> None:
|
||||
"""Reply to the last user message, calling tools as needed"""
|
||||
# group all traces under the same session
|
||||
langfuse_context.update_current_trace(session_id=self.name)
|
||||
|
||||
# These are the *raw* messages, before the moderator rewrites things
|
||||
committed = [self.exchange.messages[-1]]
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,13 @@ class ToolkitSpec:
|
|||
requires: Mapping[str, str] = field(factory=dict)
|
||||
|
||||
|
||||
@define
class ObserverSpec:
    """Profile entry selecting one observer (telemetry plugin) by name."""

    # Name of the observer to enable.
    name: str
|
||||
|
||||
|
||||
@define
|
||||
class Profile:
|
||||
"""The configuration for a run of goose"""
|
||||
|
|
@ -22,6 +29,7 @@ class Profile:
|
|||
accelerator: str
|
||||
moderator: str
|
||||
toolkits: list[ToolkitSpec] = field(factory=list, converter=ensure_list(ToolkitSpec))
|
||||
observers: list[ObserverSpec] = field(factory=list, converter=ensure_list(ObserverSpec))
|
||||
|
||||
@toolkits.validator
|
||||
def check_toolkit_requirements(self, _: type["ToolkitSpec"], toolkits: list[ToolkitSpec]) -> None:
|
||||
|
|
@ -40,8 +48,13 @@ class Profile:
|
|||
return asdict(self)
|
||||
|
||||
def profile_info(self) -> str:
|
||||
tookit_names = [toolkit.name for toolkit in self.toolkits]
|
||||
return f"provider:{self.provider}, processor:{self.processor} toolkits: {', '.join(tookit_names)}"
|
||||
toolkit_names = [toolkit.name for toolkit in self.toolkits]
|
||||
observer_names = [observer.name for observer in self.observers]
|
||||
return (
|
||||
f"provider:{self.provider}, processor:{self.processor} "
|
||||
f"toolkits: {', '.join(toolkit_names)} "
|
||||
f"observers: {', '.join(observer_names)}"
|
||||
)
|
||||
|
||||
|
||||
def default_profile(provider: str, processor: str, accelerator: str, **kwargs: dict[str, any]) -> Profile:
|
||||
|
|
@ -55,4 +68,5 @@ def default_profile(provider: str, processor: str, accelerator: str, **kwargs: d
|
|||
accelerator=accelerator,
|
||||
moderator="synopsis",
|
||||
toolkits=[ToolkitSpec("synopsis")],
|
||||
observers=[ObserverSpec("langfuse")],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
from datetime import datetime
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Union
|
||||
from unittest.mock import MagicMock, mock_open, patch
|
||||
|
||||
import pytest
|
||||
from exchange import Message, ToolResult, ToolUse
|
||||
from exchange.observers import ObserverManager
|
||||
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
|
||||
from goose.cli.prompt.overwrite_session_prompt import OverwriteSessionPrompt
|
||||
from goose.cli.prompt.user_input import PromptAction, UserInput
|
||||
|
|
@ -260,3 +261,22 @@ def test_prompt_overwrite_session(session_factory):
|
|||
choice="r",
|
||||
expected_messages=[Message.user(text="duck duck"), Message.user(text="goose")],
|
||||
)
|
||||
|
||||
|
||||
def test_observer_plugin_called(create_session_with_mock_configs):
    """Session.reply() goes through the observe_wrapper of the observer held by the manager."""
    wrapper_spy = MagicMock()
    fake_observer = MagicMock()
    fake_observer.observe_wrapper = wrapper_spy

    fake_manager = MagicMock(spec=ObserverManager)
    fake_manager._observers = [fake_observer]

    manager_patch = patch("exchange.observers.ObserverManager.get_instance", return_value=fake_manager)
    generate_patch = patch("exchange.Exchange.generate", return_value=Message.assistant("test response"))

    with manager_patch, generate_patch:
        session = create_session_with_mock_configs({"name": SESSION_NAME})

        session.exchange.messages.append(Message.user("hi"))
        session.reply()

    wrapper_spy.assert_called_once()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from goose.profile import ToolkitSpec
|
||||
from goose.profile import ToolkitSpec, ObserverSpec
|
||||
|
||||
|
||||
def test_profile_info(profile_factory):
|
||||
|
|
@ -7,6 +7,10 @@ def test_profile_info(profile_factory):
|
|||
"provider": "provider",
|
||||
"processor": "processor",
|
||||
"toolkits": [ToolkitSpec("developer"), ToolkitSpec("github")],
|
||||
"observers": [ObserverSpec(name="test.plugin")],
|
||||
}
|
||||
)
|
||||
assert profile.profile_info() == "provider:provider, processor:processor toolkits: developer, github"
|
||||
assert (
|
||||
profile.profile_info()
|
||||
== "provider:provider, processor:processor toolkits: developer, github observers: test.plugin"
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue