open-notebook/open_notebook/utils/embedding.py
unendless314 6aabacfca6
feat: use token-based sizing for embedding chunking (#749)
* feat: make chunk sizing token-based with 512-token default

* fix: defer embedding debug token metrics

* chore: lower default chunk size to 400 tokens and document rationale

The previous 512-token default matched exactly the context window of
BERT-family embedders like mxbai-embed-large, leaving no margin for:
- tokenizer mismatch between our o200k_base measurement and the
  embedder's own WordPiece tokenizer
- occasional splitter overshoot (RecursiveCharacterTextSplitter can
  emit chunks slightly above chunk_size when separators are sparse)
- special tokens ([CLS], [SEP]) that consume context-window budget

400 tokens keeps ~20% headroom below 512 while still being a large
improvement over the old character-based default for most content.
Users with larger-context embedders can raise OPEN_NOTEBOOK_CHUNK_SIZE
via env var. Also adds a CHANGELOG entry for the full PR behavior
change.

* chore: move chunking changelog entry under 1.8.5

Target release is 1.8.5 — moving the Changed section out of Unreleased.

---------

Co-authored-by: Luis Novo <lfnovo@gmail.com>
2026-04-19 13:49:09 -03:00

251 lines
8.4 KiB
Python

"""
Unified embedding utilities for Open Notebook.
Provides centralized embedding generation with support for:
- Single text embedding (with automatic chunking and mean pooling for large texts)
- Batch text embedding (multiple texts with automatic batching)
- Mean pooling for combining multiple embeddings into one
All embedding operations in the application should use these functions
to ensure consistent behavior and proper handling of large content.
"""
import asyncio
from typing import TYPE_CHECKING, List, Optional
import numpy as np
from loguru import logger
from .chunking import CHUNK_SIZE, ContentType, chunk_text
from .token_utils import token_count
# Maximum number of texts sent to the provider in a single aembed() call.
EMBEDDING_BATCH_SIZE: int = 50
# Total attempts per batch (first try included) before the batch is abandoned.
EMBEDDING_MAX_RETRIES: int = 3
# Pause between two attempts on the same batch, in seconds.
EMBEDDING_RETRY_DELAY: int = 2

# Type-checking-only import. At runtime the models module is imported lazily
# inside generate_embeddings to avoid the circular dependency:
# utils -> embedding -> models -> key_provider -> provider_config -> utils
if TYPE_CHECKING:
    from open_notebook.ai.models import ModelManager
async def mean_pool_embeddings(embeddings: List[List[float]]) -> List[float]:
    """
    Combine multiple embeddings into a single embedding using mean pooling.

    Algorithm:
    1. Normalize each embedding to unit length
    2. Compute element-wise mean
    3. Normalize the result to unit length

    This approach ensures the final embedding has the same properties as
    individual embeddings (unit length) regardless of input count. All-zero
    vectors are left unscaled, so the result is the zero vector (not unit
    length) when the inputs are zero or their normalized forms cancel out.

    The function is declared ``async`` because callers await it, but it
    performs no I/O and never yields to the event loop.

    Args:
        embeddings: List of embedding vectors (each is a list of floats)

    Returns:
        Single embedding vector (mean pooled and normalized)

    Raises:
        ValueError: If the embeddings list is empty, the embeddings have
            different dimensions or non-numeric values, or the input is not
            a list of 1-D vectors
    """
    if not embeddings:
        raise ValueError("Cannot mean pool empty list of embeddings")

    # Build a single float64 matrix up front so that every input, including
    # the single-embedding case, goes through the same shape validation.
    try:
        arr = np.array(embeddings, dtype=np.float64)
    except ValueError as e:
        # NumPy raises ValueError for ragged rows (vectors of different
        # lengths) and for values that cannot be converted to float.
        raise ValueError(
            "Cannot mean pool embeddings with different dimensions or "
            f"non-numeric values: {e}"
        ) from e

    # Each embedding must be a flat vector: rows x dimensions.
    if arr.ndim != 2:
        raise ValueError(f"Expected 2D array, got shape {arr.shape}")

    if arr.shape[0] == 1:
        # Single embedding - just normalize and return (nothing to average)
        vec = arr[0]
        norm = np.linalg.norm(vec)
        if norm > 0:
            vec = vec / norm
        return vec.tolist()

    # Normalize each embedding to unit length; all-zero rows keep a divisor
    # of 1.0 to avoid division by zero.
    norms = np.linalg.norm(arr, axis=1, keepdims=True)
    norms = np.where(norms > 0, norms, 1.0)
    normalized = arr / norms

    # Element-wise mean across embeddings
    mean = normalized.mean(axis=0)

    # Re-normalize so the pooled vector is unit length like its inputs
    mean_norm = np.linalg.norm(mean)
    if mean_norm > 0:
        mean = mean / mean_norm
    return mean.tolist()
async def generate_embeddings(
    texts: List[str], command_id: Optional[str] = None
) -> List[List[float]]:
    """
    Generate embeddings for multiple texts with automatic batching and retry.

    Texts are split into batches of EMBEDDING_BATCH_SIZE to avoid exceeding
    provider payload limits. Each batch gets up to EMBEDDING_MAX_RETRIES
    attempts, with EMBEDDING_RETRY_DELAY seconds between attempts. A response
    that does not contain exactly one embedding per input text counts as a
    failed attempt. This keeps the returned vectors aligned with ``texts``.

    Args:
        texts: List of text strings to embed
        command_id: Optional command ID for error logging context

    Returns:
        List of embedding vectors, one per input text, in input order.
        An empty ``texts`` returns ``[]`` without loading a model.

    Raises:
        ValueError: If no embedding model is configured
        RuntimeError: If a batch still fails after all attempts
    """
    if not texts:
        return []

    # Lazy import to avoid circular dependency
    from open_notebook.ai.models import model_manager

    embedding_model = await model_manager.get_embedding_model()
    if not embedding_model:
        raise ValueError(
            "No embedding model configured. Please configure one in the Models section."
        )

    model_name = getattr(embedding_model, "model_name", "unknown")
    cmd_context = f" (command: {command_id})" if command_id else ""

    # Log text sizes for debugging. The metrics are computed at most once,
    # and only when loguru evaluates the lazy arguments below (i.e. when the
    # DEBUG record is actually emitted), keeping token counting off the
    # normal path.
    metrics: tuple[int, int, int, int] | None = None

    def _get_size_metrics() -> tuple[int, int, int, int]:
        """Return (min tokens, max tokens, total tokens, total chars), cached."""
        nonlocal metrics
        if metrics is None:
            token_sizes = [token_count(t) for t in texts]
            metrics = (
                min(token_sizes),
                max(token_sizes),
                sum(token_sizes),
                sum(len(t) for t in texts),
            )
        return metrics

    logger.opt(lazy=True).debug(
        "Generating embeddings for {} texts "
        "(tokens: min={}, max={}, total={}; chars: total={})",
        lambda: len(texts),
        lambda: _get_size_metrics()[0],
        lambda: _get_size_metrics()[1],
        lambda: _get_size_metrics()[2],
        lambda: _get_size_metrics()[3],
    )

    all_embeddings: List[List[float]] = []
    total_batches = (len(texts) + EMBEDDING_BATCH_SIZE - 1) // EMBEDDING_BATCH_SIZE

    for batch_idx, start in enumerate(range(0, len(texts), EMBEDDING_BATCH_SIZE)):
        batch = texts[start : start + EMBEDDING_BATCH_SIZE]

        for attempt in range(1, EMBEDDING_MAX_RETRIES + 1):
            try:
                batch_embeddings = await embedding_model.aembed(batch)
                # A provider that drops or duplicates vectors would shift
                # every later embedding onto the wrong text; treat it as a
                # failed attempt so it is retried and then reported.
                if len(batch_embeddings) != len(batch):
                    raise RuntimeError(
                        f"model returned {len(batch_embeddings)} embeddings "
                        f"for {len(batch)} texts"
                    )
            except Exception as e:
                if attempt < EMBEDDING_MAX_RETRIES:
                    logger.debug(
                        f"Embedding batch {batch_idx + 1}/{total_batches} "
                        f"attempt {attempt}/{EMBEDDING_MAX_RETRIES} failed "
                        f"using model '{model_name}'{cmd_context}: {e}. Retrying..."
                    )
                    await asyncio.sleep(EMBEDDING_RETRY_DELAY)
                else:
                    logger.debug(
                        f"Embedding batch {batch_idx + 1}/{total_batches} "
                        f"failed after {EMBEDDING_MAX_RETRIES} attempts "
                        f"using model '{model_name}'{cmd_context}: {e}"
                    )
                    raise RuntimeError(
                        f"Failed to generate embeddings using model '{model_name}' "
                        f"(batch {batch_idx + 1}/{total_batches}, "
                        f"{len(batch)} texts): {e}"
                    ) from e
            else:
                # Success: keep this batch's vectors and stop retrying it.
                all_embeddings.extend(batch_embeddings)
                break

    logger.debug(f"Generated {len(all_embeddings)} embeddings in {total_batches} batch(es)")
    return all_embeddings
async def generate_embedding(
    text: str,
    content_type: Optional[ContentType] = None,
    file_path: Optional[str] = None,
    command_id: Optional[str] = None,
) -> List[float]:
    """
    Produce one embedding vector for ``text``, chunking and mean pooling when it is long.

    The input is stripped and measured in tokens. When it fits within
    CHUNK_SIZE tokens, it is embedded in a single request and that vector is
    returned unchanged. Otherwise it is split with ``chunk_text``, which
    receives ``content_type`` and ``file_path``. A single resulting chunk is
    embedded and returned as-is. Several chunks are embedded through
    ``generate_embeddings`` and merged with ``mean_pool_embeddings``.

    Args:
        text: The text to embed
        content_type: Optional explicit content type for chunking
        file_path: Optional file path for content type detection
        command_id: Optional command ID for error logging context

    Returns:
        Single embedding vector (list of floats)

    Raises:
        ValueError: If text is empty or whitespace-only, chunking yields no
            chunks, or no embedding model is configured
        RuntimeError: If embedding generation fails
    """
    cleaned = text.strip() if text else ""
    if not cleaned:
        raise ValueError("Cannot generate embedding for empty text")

    n_tokens = token_count(cleaned)

    # Fast path: the whole text fits in one chunk, so no splitting is needed.
    if n_tokens <= CHUNK_SIZE:
        logger.debug(f"Embedding short text ({n_tokens} tokens) directly")
        direct = await generate_embeddings([cleaned], command_id=command_id)
        return direct[0]

    logger.debug(f"Text exceeds chunk size ({n_tokens} tokens), chunking...")
    pieces = chunk_text(cleaned, content_type=content_type, file_path=file_path)
    if not pieces:
        raise ValueError("Text chunking produced no chunks")

    needs_pooling = len(pieces) > 1
    if needs_pooling:
        logger.debug(f"Embedding {len(pieces)} chunks and mean pooling")

    vectors = await generate_embeddings(pieces, command_id=command_id)
    if not needs_pooling:
        # The splitter produced a single chunk: its embedding is the answer.
        return vectors[0]

    combined = await mean_pool_embeddings(vectors)
    logger.debug(f"Mean pooled {len(vectors)} embeddings into single vector")
    return combined