mirror of
https://github.com/eigent-ai/eigent.git
synced 2026-05-24 13:43:45 +00:00
333 lines
11 KiB
Python
333 lines
11 KiB
Python
# ========= Copyright 2023-2026 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ========= Copyright 2023-2026 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
|
|
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence
|
|
|
|
import numpy as np
|
|
from datasets import Dataset, load_dataset
|
|
|
|
from camel.agents import ChatAgent
|
|
from camel.benchmarks import BaseBenchmark
|
|
from camel.logger import get_logger
|
|
from camel.retrievers import AutoRetriever
|
|
|
|
# Module-level logger shared by the metric helpers and the benchmark class
# defined in this file.
logger = get_logger(__name__)
|
|
|
|
|
|
class RagasFields:
    r"""Column names used for the RAGAS evaluation inputs.

    Datasets handed to the RAGAS evaluator in this module have their
    context and answer columns renamed to these values first.
    """

    # Column holding the list of context passages for an example.
    INPUT_CONTEXT = "contexts"
    # Column holding the question text.
    INPUT_QUESTION = "question"
    # Column holding the generated answer.
    INPUT_ANSWER = "answer"
|
|
|
|
|
|
def annotate_dataset(
    dataset: Dataset,
    context_call: Optional[Callable[[Dict[str, Any]], List[str]]],
    answer_call: Optional[Callable[[Dict[str, Any]], str]],
) -> Dataset:
    r"""Add generated contexts and/or answers to every example of a dataset.

    Args:
        dataset (Dataset): The dataset whose examples are annotated.
        context_call (Optional[Callable[[Dict[str, Any]], List[str]]]):
            Produces the contexts for one example; skipped when falsy.
        answer_call (Optional[Callable[[Dict[str, Any]], str]]): Produces
            the answer for one example; skipped when falsy.

    Returns:
        Dataset: The result of mapping the annotation step over
            :obj:`dataset`.
    """
    # Order matters: contexts are written first, so the answer producer
    # receives an example that already carries its "contexts" entry.
    producers = (
        ("contexts", context_call),
        ("answer", answer_call),
    )

    def _annotate(row: Dict[str, Any]) -> Dict[str, Any]:
        for column, producer in producers:
            if producer:
                row[column] = producer(row)
        return row

    return dataset.map(_annotate)
|
|
|
|
|
|
def rmse(
    input_trues: Sequence[float],
    input_preds: Sequence[float],
) -> Optional[float]:
    r"""Calculate Root Mean Squared Error (RMSE).

    Pairs in which either the ground truth or the prediction is missing
    (:obj:`None`) or NaN are left out of the computation, so a single
    unusable value cannot turn the whole score into NaN.

    Args:
        input_trues (Sequence[float]): Ground truth values.
        input_preds (Sequence[float]): Predicted values.

    Returns:
        Optional[float]: RMSE over the usable pairs, or None if the inputs
            have different lengths or no usable pair remains.
    """
    if len(input_trues) != len(input_preds):
        logger.warning("Input lengths mismatch in RMSE calculation")
        return None

    # Converting with dtype=float maps None to NaN, which lets both sides be
    # filtered with the same NaN mask below.
    trues = np.asarray(input_trues, dtype=float)
    preds = np.asarray(input_preds, dtype=float)

    # Keep only the positions where both values are real numbers.
    valid = ~(np.isnan(trues) | np.isnan(preds))
    if not valid.any():
        logger.warning("No valid predictions for RMSE calculation")
        return None

    errors = preds[valid] - trues[valid]
    return float(np.sqrt(np.mean(errors**2)))
|
|
|
|
|
|
def auroc(trues: Sequence[bool], preds: Sequence[float]) -> float:
    r"""Calculate Area Under Receiver Operating Characteristic Curve (AUROC).

    Predictions that are missing (:obj:`None`) or NaN are dropped together
    with their ground truth before scoring. When the score cannot be
    computed (no usable prediction, or fewer than two distinct ground truth
    classes left) a warning is logged and 0.5 is returned.

    Args:
        trues (Sequence[bool]): Ground truth binary values.
        preds (Sequence[float]): Predicted probability values.

    Returns:
        float: AUROC score, or 0.5 when the score is undefined.
    """
    from sklearn.metrics import roc_auc_score  # type: ignore[import-untyped]

    # Convert once; dtype=float maps None to NaN so it is filtered as well.
    scores = np.asarray(preds, dtype=float)
    valid = ~np.isnan(scores)
    if not valid.any():
        logger.warning("No valid predictions for AUROC calculation")
        return 0.5  # Return random classifier score

    labels = np.asarray(trues)[valid]
    # AUROC needs both classes; fall back to the same neutral score used
    # above instead of letting the scorer fail on a single-class input.
    if len(set(labels.tolist())) < 2:
        logger.warning(
            "Only one class present in ground truth for AUROC calculation"
        )
        return 0.5

    return float(roc_auc_score(labels, scores[valid]))
|
|
|
|
|
|
def ragas_calculate_metrics(
    dataset: Dataset,
    pred_context_relevance_field: Optional[str],
    pred_faithfulness_field: Optional[str],
    metrics_to_evaluate: Optional[List[str]] = None,
    ground_truth_context_relevance_field: str = "relevance_score",
    ground_truth_faithfulness_field: str = "adherence_score",
) -> Dict[str, Optional[float]]:
    r"""Compare predicted RAGAS scores against the dataset's ground truth.

    Context relevancy is scored with RMSE and stored under
    ``"relevance_rmse"``; faithfulness is turned into a hallucination
    detection problem, scored with AUROC and stored under
    ``"hallucination_auroc"``. A metric is only computed when it is
    requested and its prediction field name is given.

    Args:
        dataset (Dataset): The dataset containing predictions and ground
            truth.
        pred_context_relevance_field (Optional[str]): Field name for
            predicted context relevance.
        pred_faithfulness_field (Optional[str]): Field name for predicted
            faithfulness.
        metrics_to_evaluate (Optional[List[str]]): List of metrics to
            evaluate. A falsy value selects both metrics.
        ground_truth_context_relevance_field (str): Field name for ground
            truth relevance.
        ground_truth_faithfulness_field (str): Field name for ground truth
            adherence.

    Returns:
        Dict[str, Optional[float]]: Dictionary of calculated metrics.
    """
    requested = metrics_to_evaluate or ["context_relevancy", "faithfulness"]
    results: Dict[str, Optional[float]] = {}

    if pred_context_relevance_field and "context_relevancy" in requested:
        results["relevance_rmse"] = rmse(
            dataset[ground_truth_context_relevance_field],
            dataset[pred_context_relevance_field],
        )

    if pred_faithfulness_field and "faithfulness" in requested:
        # Both sides are flipped so that the positive class is
        # "hallucinated": labels are inverted adherence, scores are
        # 1 - faithfulness.
        adherence = np.array(dataset[ground_truth_faithfulness_field])
        hallucinated = ~adherence
        faithfulness_scores = np.array(
            dataset[pred_faithfulness_field], dtype=float
        )
        hallucination_scores = 1 - faithfulness_scores
        results["hallucination_auroc"] = auroc(
            hallucinated.tolist(), hallucination_scores.tolist()
        )

    return results
|
|
|
|
|
|
def ragas_evaluate_dataset(
    dataset: Dataset,
    contexts_field_name: Optional[str],
    answer_field_name: Optional[str],
    metrics_to_evaluate: Optional[List[str]] = None,
) -> Dataset:
    r"""Score a dataset with the selected RAGAS metrics.

    Args:
        dataset (Dataset): Input dataset to evaluate.
        contexts_field_name (Optional[str]): Field name containing
            contexts. Renamed to :obj:`RagasFields.INPUT_CONTEXT` when it
            differs from it.
        answer_field_name (Optional[str]): Field name containing answers.
            Renamed to :obj:`RagasFields.INPUT_ANSWER` when it differs
            from it.
        metrics_to_evaluate (Optional[List[str]]): List of metrics to
            evaluate. A falsy value selects both supported metrics.

    Returns:
        Dataset: Dataset built from the RAGAS result table.
    """
    from ragas import evaluate  # type: ignore[import]
    from ragas.metrics import (  # type: ignore[import]
        context_relevancy,
        faithfulness,
    )

    requested = metrics_to_evaluate or ["context_relevancy", "faithfulness"]

    # Rename caller-specific columns to the names used for RAGAS input.
    column_renames = (
        (contexts_field_name, RagasFields.INPUT_CONTEXT),
        (answer_field_name, RagasFields.INPUT_ANSWER),
    )
    for current_name, ragas_name in column_renames:
        if current_name and current_name != ragas_name:
            dataset = dataset.rename_column(current_name, ragas_name)

    # Supported metrics in the fixed order they are handed to RAGAS.
    supported = (
        ("context_relevancy", context_relevancy),
        ("faithfulness", faithfulness),
    )
    selected = [metric for name, metric in supported if name in requested]

    scored = evaluate(dataset, metrics=selected)
    return Dataset.from_pandas(scored.to_pandas())
|
|
|
|
|
|
class RAGBenchBenchmark(BaseBenchmark):
    r"""RAGBench Benchmark for evaluating RAG performance.

    This benchmark uses the rungalileo/ragbench dataset to evaluate
    retrieval-augmented generation (RAG) systems. It measures context
    relevancy and faithfulness metrics as described in
    https://arxiv.org/abs/2407.11005.

    Args:
        processes (int, optional): Number of processes for parallel
            processing. (default: :obj:`1`)
        subset (str, optional): Dataset subset to use (e.g., "hotpotqa").
            (default: :obj:`"hotpotqa"`)
        split (str, optional): Dataset split to use (e.g., "test").
            (default: :obj:`"test"`)
    """

    def __init__(
        self,
        processes: int = 1,
        subset: Literal[
            "covidqa",
            "cuad",
            "delucionqa",
            "emanual",
            "expertqa",
            "finqa",
            "hagrid",
            "hotpotqa",
            "msmarco",
            "pubmedqa",
            "tatqa",
            "techqa",
        ] = "hotpotqa",
        split: Literal["train", "test", "validation"] = "test",
    ) -> None:
        super().__init__("ragbench", "rag_bench", "", processes)
        self.subset = subset
        self.split = split
        # Filled in by download(); None until the dataset has been fetched.
        self.dataset: Optional[Dataset] = None

    def download(self):
        r"""Download the RAGBench dataset.

        Stores the selected subset/split on :obj:`self.dataset`.

        Raises:
            Exception: Any error raised while loading the dataset is logged
                and re-raised unchanged.
        """
        try:
            self.dataset = load_dataset(
                "rungalileo/ragbench", self.subset, split=self.split
            )
        except Exception as e:
            logger.error("Failed to download dataset: %s", e)
            raise

    def load(self, force_download: bool = False):
        r"""Load the RAGBench dataset.

        Does nothing when a dataset is already held and
        :obj:`force_download` is False.

        Args:
            force_download (bool, optional): Whether to force download the
                data. (default: :obj:`False`)
        """
        if force_download or self.dataset is None:
            logger.info(
                "%s dataset",
                "Force downloading" if force_download else "Loading",
            )
            self.download()

    def run(  # type: ignore[override, return]
        self,
        agent: ChatAgent,
        auto_retriever: AutoRetriever,
    ) -> Dict[str, Optional[float]]:
        r"""Run the benchmark evaluation.

        The dataset is loaded on demand if :meth:`load` has not been called
        yet. Each example is annotated with retrieved contexts and an agent
        answer, scored with RAGAS, and the scores are compared against the
        ground truth.

        Args:
            agent (ChatAgent): Chat agent for generating answers.
            auto_retriever (AutoRetriever): Retriever for finding relevant
                contexts.

        Returns:
            Dict[str, Optional[float]]: Dictionary of evaluation metrics.
        """
        # Without this, a run() before load() would pass None on to
        # annotate_dataset and fail with an AttributeError.
        if self.dataset is None:
            self.load()

        def context_call(example: Dict[str, Any]) -> List[str]:
            # Retrieve the single best-matching passage from the example's
            # own documents.
            retrieved_info = auto_retriever.run_vector_retriever(
                query=example[RagasFields.INPUT_QUESTION],
                contents=example["documents"],
                top_k=1,
                return_detailed_info=True,
                similarity_threshold=0.5,
            )
            return [c["text"] for c in retrieved_info["Retrieved Context"]]

        def answer_call(example: Dict[str, Any]) -> str:
            # NOTE(review): the whole example row is stringified and sent
            # to the agent, so every column present on the row (not only
            # the question and contexts) ends up in the prompt - confirm
            # this is intended.
            user_msg = str(example)
            assistant_response = agent.step(user_msg)
            return assistant_response.msg.content

        # Annotate the dataset
        annotated_ds = annotate_dataset(
            self.dataset, context_call, answer_call
        )
        evaluated_ds = ragas_evaluate_dataset(
            annotated_ds,
            contexts_field_name=RagasFields.INPUT_CONTEXT,
            answer_field_name=RagasFields.INPUT_ANSWER,
            metrics_to_evaluate=["context_relevancy", "faithfulness"],
        )

        return ragas_calculate_metrics(
            evaluated_ds,
            pred_context_relevance_field="context_relevancy",
            pred_faithfulness_field="faithfulness",
        )
|