mirror of
https://github.com/block/goose.git
synced 2026-04-28 03:29:36 +00:00
Add vision/image support for local inference models (#8442)
Some checks failed
Canary / Prepare Version (push) Waiting to run
Canary / build-cli (push) Blocked by required conditions
Canary / Upload Install Script (push) Blocked by required conditions
Canary / bundle-desktop (push) Blocked by required conditions
Canary / bundle-desktop-intel (push) Blocked by required conditions
Canary / bundle-desktop-linux (push) Blocked by required conditions
Canary / bundle-desktop-windows (push) Blocked by required conditions
Canary / Release (push) Blocked by required conditions
Unused Dependencies / machete (push) Waiting to run
CI / changes (push) Waiting to run
CI / Check Rust Code Format (push) Blocked by required conditions
CI / Build and Test Rust Project (push) Blocked by required conditions
CI / Build Rust Project on Windows (push) Waiting to run
CI / Lint Rust Code (push) Blocked by required conditions
CI / Check Generated Schemas are Up-to-Date (push) Blocked by required conditions
CI / Test and Lint Electron Desktop App (push) Blocked by required conditions
Live Provider Tests / check-fork (push) Waiting to run
Live Provider Tests / changes (push) Blocked by required conditions
Live Provider Tests / Build Binary (push) Blocked by required conditions
Live Provider Tests / Smoke Tests (push) Blocked by required conditions
Live Provider Tests / Smoke Tests (Code Execution) (push) Blocked by required conditions
Live Provider Tests / Compaction Tests (push) Blocked by required conditions
Live Provider Tests / goose server HTTP integration tests (push) Blocked by required conditions
Publish Docker Image / docker (push) Waiting to run
Scorecard supply-chain security / Scorecard analysis (push) Waiting to run
Cargo Deny / deny (push) Has been cancelled
Some checks failed
Canary / Prepare Version (push) Waiting to run
Canary / build-cli (push) Blocked by required conditions
Canary / Upload Install Script (push) Blocked by required conditions
Canary / bundle-desktop (push) Blocked by required conditions
Canary / bundle-desktop-intel (push) Blocked by required conditions
Canary / bundle-desktop-linux (push) Blocked by required conditions
Canary / bundle-desktop-windows (push) Blocked by required conditions
Canary / Release (push) Blocked by required conditions
Unused Dependencies / machete (push) Waiting to run
CI / changes (push) Waiting to run
CI / Check Rust Code Format (push) Blocked by required conditions
CI / Build and Test Rust Project (push) Blocked by required conditions
CI / Build Rust Project on Windows (push) Waiting to run
CI / Lint Rust Code (push) Blocked by required conditions
CI / Check Generated Schemas are Up-to-Date (push) Blocked by required conditions
CI / Test and Lint Electron Desktop App (push) Blocked by required conditions
Live Provider Tests / check-fork (push) Waiting to run
Live Provider Tests / changes (push) Blocked by required conditions
Live Provider Tests / Build Binary (push) Blocked by required conditions
Live Provider Tests / Smoke Tests (push) Blocked by required conditions
Live Provider Tests / Smoke Tests (Code Execution) (push) Blocked by required conditions
Live Provider Tests / Compaction Tests (push) Blocked by required conditions
Live Provider Tests / goose server HTTP integration tests (push) Blocked by required conditions
Publish Docker Image / docker (push) Waiting to run
Scorecard supply-chain security / Scorecard analysis (push) Waiting to run
Cargo Deny / deny (push) Has been cancelled
Signed-off-by: jh-block <jhugo@block.xyz>
This commit is contained in:
parent
5fa2a8b821
commit
de317d5445
15 changed files with 1181 additions and 88 deletions
|
|
@ -1604,6 +1604,9 @@ async fn handle_local_models_command(command: LocalModelsCommand) -> Result<()>
|
||||||
source_url: file.download_url.clone(),
|
source_url: file.download_url.clone(),
|
||||||
settings: Default::default(),
|
settings: Default::default(),
|
||||||
size_bytes: file.size_bytes,
|
size_bytes: file.size_bytes,
|
||||||
|
mmproj_path: None,
|
||||||
|
mmproj_source_url: None,
|
||||||
|
mmproj_size_bytes: 0,
|
||||||
};
|
};
|
||||||
|
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
use crate::routes::errors::ErrorResponse;
|
use crate::routes::errors::ErrorResponse;
|
||||||
use crate::state::AppState;
|
use crate::state::AppState;
|
||||||
use axum::{
|
use axum::{
|
||||||
|
|
@ -13,9 +15,9 @@ use goose::providers::local_inference::{
|
||||||
available_inference_memory_bytes,
|
available_inference_memory_bytes,
|
||||||
hf_models::{resolve_model_spec, HfGgufFile},
|
hf_models::{resolve_model_spec, HfGgufFile},
|
||||||
local_model_registry::{
|
local_model_registry::{
|
||||||
default_settings_for_model, get_registry, is_featured_model, model_id_from_repo,
|
default_settings_for_model, featured_mmproj_spec, get_registry, is_featured_model,
|
||||||
LocalModelEntry, ModelDownloadStatus as RegistryDownloadStatus, ModelSettings,
|
model_id_from_repo, LocalModelEntry, ModelDownloadStatus as RegistryDownloadStatus,
|
||||||
FEATURED_MODELS,
|
ModelSettings, FEATURED_MODELS,
|
||||||
},
|
},
|
||||||
recommend_local_model,
|
recommend_local_model,
|
||||||
};
|
};
|
||||||
|
|
@ -47,10 +49,14 @@ pub struct LocalModelResponse {
|
||||||
pub status: ModelDownloadStatus,
|
pub status: ModelDownloadStatus,
|
||||||
pub recommended: bool,
|
pub recommended: bool,
|
||||||
pub settings: ModelSettings,
|
pub settings: ModelSettings,
|
||||||
|
pub vision_capable: bool,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub mmproj_status: Option<ModelDownloadStatus>,
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn ensure_featured_models_in_registry() -> Result<(), ErrorResponse> {
|
async fn ensure_featured_models_in_registry() -> Result<(), ErrorResponse> {
|
||||||
let mut entries_to_add = Vec::new();
|
let mut entries_to_add = Vec::new();
|
||||||
|
let mut mmproj_downloads_needed: Vec<(String, String, PathBuf)> = Vec::new();
|
||||||
|
|
||||||
for featured in FEATURED_MODELS {
|
for featured in FEATURED_MODELS {
|
||||||
let (repo_id, quantization) = match hf_models::parse_model_spec(featured.spec) {
|
let (repo_id, quantization) = match hf_models::parse_model_spec(featured.spec) {
|
||||||
|
|
@ -64,9 +70,28 @@ async fn ensure_featured_models_in_registry() -> Result<(), ErrorResponse> {
|
||||||
let registry = get_registry()
|
let registry = get_registry()
|
||||||
.lock()
|
.lock()
|
||||||
.map_err(|_| ErrorResponse::internal("Failed to acquire registry lock"))?;
|
.map_err(|_| ErrorResponse::internal("Failed to acquire registry lock"))?;
|
||||||
if registry.has_model(&model_id) {
|
if let Some(existing) = registry.get_model(&model_id) {
|
||||||
|
let needs_backfill = existing.mmproj_path.is_none() && featured.mmproj.is_some();
|
||||||
|
let needs_download = existing.is_downloaded()
|
||||||
|
&& featured.mmproj.is_some()
|
||||||
|
&& !existing.mmproj_path.as_ref().is_some_and(|p| p.exists());
|
||||||
|
|
||||||
|
if needs_download {
|
||||||
|
if let Some(mmproj) = featured.mmproj.as_ref() {
|
||||||
|
let path = mmproj.local_path();
|
||||||
|
let url = format!(
|
||||||
|
"https://huggingface.co/{}/resolve/main/{}",
|
||||||
|
mmproj.repo, mmproj.filename
|
||||||
|
);
|
||||||
|
mmproj_downloads_needed.push((model_id.clone(), url, path));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !needs_backfill {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
// Fall through to build the entry for sync_with_featured backfill
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let hf_file = match resolve_model_spec(featured.spec).await {
|
let hf_file = match resolve_model_spec(featured.spec).await {
|
||||||
|
|
@ -91,6 +116,8 @@ async fn ensure_featured_models_in_registry() -> Result<(), ErrorResponse> {
|
||||||
|
|
||||||
let local_path = Paths::in_data_dir("models").join(&hf_file.filename);
|
let local_path = Paths::in_data_dir("models").join(&hf_file.filename);
|
||||||
|
|
||||||
|
// enrich_with_featured_mmproj is called by sync_with_featured/add_model,
|
||||||
|
// so we don't need to populate mmproj fields here.
|
||||||
entries_to_add.push(LocalModelEntry {
|
entries_to_add.push(LocalModelEntry {
|
||||||
id: model_id.clone(),
|
id: model_id.clone(),
|
||||||
repo_id,
|
repo_id,
|
||||||
|
|
@ -100,16 +127,60 @@ async fn ensure_featured_models_in_registry() -> Result<(), ErrorResponse> {
|
||||||
source_url: hf_file.download_url,
|
source_url: hf_file.download_url,
|
||||||
settings: default_settings_for_model(&model_id),
|
settings: default_settings_for_model(&model_id),
|
||||||
size_bytes: hf_file.size_bytes,
|
size_bytes: hf_file.size_bytes,
|
||||||
|
mmproj_path: None,
|
||||||
|
mmproj_source_url: None,
|
||||||
|
mmproj_size_bytes: 0,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
if !entries_to_add.is_empty() {
|
{
|
||||||
let mut registry = get_registry()
|
let mut registry = get_registry()
|
||||||
.lock()
|
.lock()
|
||||||
.map_err(|_| ErrorResponse::internal("Failed to acquire registry lock"))?;
|
.map_err(|_| ErrorResponse::internal("Failed to acquire registry lock"))?;
|
||||||
|
|
||||||
|
if !entries_to_add.is_empty() {
|
||||||
registry.sync_with_featured(entries_to_add);
|
registry.sync_with_featured(entries_to_add);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Backfill mmproj data for all registry models and collect any
|
||||||
|
// needed mmproj downloads for models already on disk.
|
||||||
|
for model in registry.list_models_mut() {
|
||||||
|
model.enrich_with_featured_mmproj();
|
||||||
|
if model.is_downloaded() {
|
||||||
|
if let Some(mmproj) = featured_mmproj_spec(&model.id) {
|
||||||
|
let path = mmproj.local_path();
|
||||||
|
if !path.exists() {
|
||||||
|
let url = format!(
|
||||||
|
"https://huggingface.co/{}/resolve/main/{}",
|
||||||
|
mmproj.repo, mmproj.filename
|
||||||
|
);
|
||||||
|
mmproj_downloads_needed.push((model.id.clone(), url, path));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let _ = registry.save();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auto-download mmproj files for models that are already downloaded.
|
||||||
|
// Deduplicate by path since multiple quants share one mmproj file.
|
||||||
|
let dm = get_download_manager();
|
||||||
|
let mut started_paths = std::collections::HashSet::new();
|
||||||
|
for (model_id, url, path) in mmproj_downloads_needed {
|
||||||
|
if !path.exists() && started_paths.insert(path.clone()) {
|
||||||
|
let download_id = format!("{}-mmproj", model_id);
|
||||||
|
let dominated_by_active = dm
|
||||||
|
.get_progress(&download_id)
|
||||||
|
.is_some_and(|p| p.status == goose::download_manager::DownloadStatus::Downloading);
|
||||||
|
if !dominated_by_active {
|
||||||
|
tracing::info!(model_id = %model_id, "Auto-downloading vision encoder for existing model");
|
||||||
|
if let Err(e) = dm.download_model(download_id, url, path, None).await {
|
||||||
|
tracing::warn!(model_id = %model_id, error = %e, "Failed to start mmproj download");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -154,6 +225,28 @@ pub async fn list_local_models(
|
||||||
|
|
||||||
let size_bytes = entry.file_size();
|
let size_bytes = entry.file_size();
|
||||||
|
|
||||||
|
let vision_capable = entry.settings.vision_capable;
|
||||||
|
let mmproj_status = if vision_capable {
|
||||||
|
let ms = entry.mmproj_download_status();
|
||||||
|
Some(match ms {
|
||||||
|
RegistryDownloadStatus::NotDownloaded => ModelDownloadStatus::NotDownloaded,
|
||||||
|
RegistryDownloadStatus::Downloading {
|
||||||
|
progress_percent,
|
||||||
|
bytes_downloaded,
|
||||||
|
total_bytes,
|
||||||
|
speed_bps,
|
||||||
|
} => ModelDownloadStatus::Downloading {
|
||||||
|
progress_percent,
|
||||||
|
bytes_downloaded,
|
||||||
|
total_bytes,
|
||||||
|
speed_bps: Some(speed_bps),
|
||||||
|
},
|
||||||
|
RegistryDownloadStatus::Downloaded => ModelDownloadStatus::Downloaded,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
models.push(LocalModelResponse {
|
models.push(LocalModelResponse {
|
||||||
id: entry.id.clone(),
|
id: entry.id.clone(),
|
||||||
repo_id: entry.repo_id.clone(),
|
repo_id: entry.repo_id.clone(),
|
||||||
|
|
@ -163,6 +256,8 @@ pub async fn list_local_models(
|
||||||
status,
|
status,
|
||||||
recommended: recommended_id == entry.id,
|
recommended: recommended_id == entry.id,
|
||||||
settings: entry.settings.clone(),
|
settings: entry.settings.clone(),
|
||||||
|
vision_capable,
|
||||||
|
mmproj_status,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -276,16 +371,26 @@ pub async fn download_hf_model(
|
||||||
source_url: download_url.clone(),
|
source_url: download_url.clone(),
|
||||||
settings: default_settings_for_model(&model_id),
|
settings: default_settings_for_model(&model_id),
|
||||||
size_bytes: hf_file.size_bytes,
|
size_bytes: hf_file.size_bytes,
|
||||||
|
mmproj_path: None,
|
||||||
|
mmproj_source_url: None,
|
||||||
|
mmproj_size_bytes: 0,
|
||||||
};
|
};
|
||||||
|
|
||||||
{
|
// add_model enriches the entry with mmproj metadata from the featured table
|
||||||
|
let mmproj_path = {
|
||||||
let mut registry = get_registry()
|
let mut registry = get_registry()
|
||||||
.lock()
|
.lock()
|
||||||
.map_err(|_| ErrorResponse::internal("Failed to acquire registry lock"))?;
|
.map_err(|_| ErrorResponse::internal("Failed to acquire registry lock"))?;
|
||||||
registry
|
registry
|
||||||
.add_model(entry)
|
.add_model(entry)
|
||||||
.map_err(|e| ErrorResponse::internal(format!("{}", e)))?;
|
.map_err(|e| ErrorResponse::internal(format!("{}", e)))?;
|
||||||
}
|
registry.get_model(&model_id).and_then(|e| {
|
||||||
|
e.mmproj_path
|
||||||
|
.as_ref()
|
||||||
|
.zip(e.mmproj_source_url.as_ref())
|
||||||
|
.map(|(p, u)| (p.clone(), u.clone()))
|
||||||
|
})
|
||||||
|
};
|
||||||
|
|
||||||
let dm = get_download_manager();
|
let dm = get_download_manager();
|
||||||
dm.download_model(
|
dm.download_model(
|
||||||
|
|
@ -297,6 +402,19 @@ pub async fn download_hf_model(
|
||||||
.await
|
.await
|
||||||
.map_err(|e| ErrorResponse::internal(format!("Download failed: {}", e)))?;
|
.map_err(|e| ErrorResponse::internal(format!("Download failed: {}", e)))?;
|
||||||
|
|
||||||
|
if let Some((mmproj_path, mmproj_url)) = mmproj_path {
|
||||||
|
if !mmproj_path.exists() {
|
||||||
|
dm.download_model(
|
||||||
|
format!("{}-mmproj", model_id),
|
||||||
|
mmproj_url,
|
||||||
|
mmproj_path,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ErrorResponse::internal(format!("mmproj download failed: {}", e)))?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Ok((StatusCode::ACCEPTED, Json(model_id)))
|
Ok((StatusCode::ACCEPTED, Json(model_id)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -338,6 +456,7 @@ pub async fn cancel_local_model_download(
|
||||||
manager
|
manager
|
||||||
.cancel_download(&format!("{}-model", model_id))
|
.cancel_download(&format!("{}-model", model_id))
|
||||||
.map_err(|e| ErrorResponse::internal(format!("{}", e)))?;
|
.map_err(|e| ErrorResponse::internal(format!("{}", e)))?;
|
||||||
|
let _ = manager.cancel_download(&format!("{}-mmproj", model_id));
|
||||||
|
|
||||||
Ok(StatusCode::OK)
|
Ok(StatusCode::OK)
|
||||||
}
|
}
|
||||||
|
|
@ -351,14 +470,22 @@ pub async fn cancel_local_model_download(
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
pub async fn delete_local_model(Path(model_id): Path<String>) -> Result<StatusCode, ErrorResponse> {
|
pub async fn delete_local_model(Path(model_id): Path<String>) -> Result<StatusCode, ErrorResponse> {
|
||||||
let local_path = {
|
let (local_path, mmproj_path, other_uses_mmproj) = {
|
||||||
let registry = get_registry()
|
let registry = get_registry()
|
||||||
.lock()
|
.lock()
|
||||||
.map_err(|_| ErrorResponse::internal("Failed to acquire registry lock"))?;
|
.map_err(|_| ErrorResponse::internal("Failed to acquire registry lock"))?;
|
||||||
let entry = registry
|
let entry = registry
|
||||||
.get_model(&model_id)
|
.get_model(&model_id)
|
||||||
.ok_or_else(|| ErrorResponse::not_found("Model not found"))?;
|
.ok_or_else(|| ErrorResponse::not_found("Model not found"))?;
|
||||||
entry.local_path.clone()
|
let lp = entry.local_path.clone();
|
||||||
|
let mp = entry.mmproj_path.clone();
|
||||||
|
// Check if another downloaded model shares this mmproj file
|
||||||
|
let shared = mp.as_ref().is_some_and(|target| {
|
||||||
|
registry.list_models().iter().any(|m| {
|
||||||
|
m.id != model_id && m.is_downloaded() && m.mmproj_path.as_ref() == Some(target)
|
||||||
|
})
|
||||||
|
});
|
||||||
|
(lp, mp, shared)
|
||||||
};
|
};
|
||||||
|
|
||||||
if local_path.exists() {
|
if local_path.exists() {
|
||||||
|
|
@ -367,6 +494,14 @@ pub async fn delete_local_model(Path(model_id): Path<String>) -> Result<StatusCo
|
||||||
.map_err(|e| ErrorResponse::internal(format!("Failed to delete: {}", e)))?;
|
.map_err(|e| ErrorResponse::internal(format!("Failed to delete: {}", e)))?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !other_uses_mmproj {
|
||||||
|
if let Some(mmproj) = mmproj_path {
|
||||||
|
if mmproj.exists() {
|
||||||
|
let _ = tokio::fs::remove_file(&mmproj).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Only remove non-featured models from registry (featured ones stay as placeholders)
|
// Only remove non-featured models from registry (featured ones stay as placeholders)
|
||||||
if !is_featured_model(&model_id) {
|
if !is_featured_model(&model_id) {
|
||||||
let mut registry = get_registry()
|
let mut registry = get_registry()
|
||||||
|
|
|
||||||
|
|
@ -179,7 +179,7 @@ tree-sitter-typescript = { workspace = true }
|
||||||
which = { workspace = true }
|
which = { workspace = true }
|
||||||
pctx_code_mode = { version = "^0.3.0", optional = true }
|
pctx_code_mode = { version = "^0.3.0", optional = true }
|
||||||
pulldown-cmark = "0.13.0"
|
pulldown-cmark = "0.13.0"
|
||||||
llama-cpp-2 = { version = "0.1.143", features = ["sampler"], optional = true }
|
llama-cpp-2 = { version = "0.1.143", features = ["sampler", "mtmd"], optional = true }
|
||||||
encoding_rs = "0.8.35"
|
encoding_rs = "0.8.35"
|
||||||
pastey = "0.2.1"
|
pastey = "0.2.1"
|
||||||
shell-words = { workspace = true }
|
shell-words = { workspace = true }
|
||||||
|
|
@ -197,7 +197,7 @@ keyring = { version = "3.6.2", features = ["windows-native"] }
|
||||||
[target.'cfg(target_os = "macos")'.dependencies]
|
[target.'cfg(target_os = "macos")'.dependencies]
|
||||||
candle-core = { version = "0.9", default-features = false, features = ["metal"], optional = true }
|
candle-core = { version = "0.9", default-features = false, features = ["metal"], optional = true }
|
||||||
candle-nn = { version = "0.9", default-features = false, features = ["metal"], optional = true }
|
candle-nn = { version = "0.9", default-features = false, features = ["metal"], optional = true }
|
||||||
llama-cpp-2 = { version = "0.1.143", features = ["sampler", "metal"], optional = true }
|
llama-cpp-2 = { version = "0.1.143", features = ["sampler", "metal", "mtmd"], optional = true }
|
||||||
keyring = { version = "3.6.2", features = ["apple-native"] }
|
keyring = { version = "3.6.2", features = ["apple-native"] }
|
||||||
|
|
||||||
[target.'cfg(target_os = "linux")'.dependencies]
|
[target.'cfg(target_os = "linux")'.dependencies]
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ mod inference_emulated_tools;
|
||||||
mod inference_engine;
|
mod inference_engine;
|
||||||
mod inference_native_tools;
|
mod inference_native_tools;
|
||||||
pub mod local_model_registry;
|
pub mod local_model_registry;
|
||||||
|
pub(crate) mod multimodal;
|
||||||
mod tool_parsing;
|
mod tool_parsing;
|
||||||
|
|
||||||
use inference_emulated_tools::{
|
use inference_emulated_tools::{
|
||||||
|
|
@ -30,6 +31,7 @@ use llama_cpp_2::llama_backend::LlamaBackend;
|
||||||
use llama_cpp_2::model::params::LlamaModelParams;
|
use llama_cpp_2::model::params::LlamaModelParams;
|
||||||
use llama_cpp_2::model::{LlamaChatMessage, LlamaChatTemplate, LlamaModel};
|
use llama_cpp_2::model::{LlamaChatMessage, LlamaChatTemplate, LlamaModel};
|
||||||
use llama_cpp_2::{list_llama_ggml_backend_devices, LlamaBackendDeviceType, LogOptions};
|
use llama_cpp_2::{list_llama_ggml_backend_devices, LlamaBackendDeviceType, LogOptions};
|
||||||
|
use multimodal::ExtractedImage;
|
||||||
use rmcp::model::{Role, Tool};
|
use rmcp::model::{Role, Tool};
|
||||||
use serde_json::{json, Value};
|
use serde_json::{json, Value};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
@ -114,14 +116,15 @@ const DEFAULT_MODEL: &str = "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M";
|
||||||
|
|
||||||
pub const LOCAL_LLM_MODEL_CONFIG_KEY: &str = "LOCAL_LLM_MODEL";
|
pub const LOCAL_LLM_MODEL_CONFIG_KEY: &str = "LOCAL_LLM_MODEL";
|
||||||
|
|
||||||
/// Resolve model path, context limit, and settings for a model ID from the registry.
|
pub struct ResolvedModelPaths {
|
||||||
pub fn resolve_model_path(
|
pub model_path: PathBuf,
|
||||||
model_id: &str,
|
pub context_limit: usize,
|
||||||
) -> Option<(
|
pub settings: crate::providers::local_inference::local_model_registry::ModelSettings,
|
||||||
PathBuf,
|
pub mmproj_path: Option<PathBuf>,
|
||||||
usize,
|
}
|
||||||
crate::providers::local_inference::local_model_registry::ModelSettings,
|
|
||||||
)> {
|
/// Resolve model path, context limit, settings, and mmproj path for a model ID from the registry.
|
||||||
|
pub fn resolve_model_path(model_id: &str) -> Option<ResolvedModelPaths> {
|
||||||
use crate::providers::local_inference::local_model_registry::{
|
use crate::providers::local_inference::local_model_registry::{
|
||||||
default_settings_for_model, get_registry,
|
default_settings_for_model, get_registry,
|
||||||
};
|
};
|
||||||
|
|
@ -135,7 +138,15 @@ pub fn resolve_model_path(
|
||||||
// recognized (or with a different quantization) still get the right behavior.
|
// recognized (or with a different quantization) still get the right behavior.
|
||||||
let defaults = default_settings_for_model(model_id);
|
let defaults = default_settings_for_model(model_id);
|
||||||
settings.native_tool_calling = defaults.native_tool_calling;
|
settings.native_tool_calling = defaults.native_tool_calling;
|
||||||
return Some((entry.local_path.clone(), ctx, settings));
|
settings.vision_capable = defaults.vision_capable;
|
||||||
|
settings.mmproj_size_bytes = entry.mmproj_size_bytes;
|
||||||
|
let mmproj_path = entry.mmproj_path.as_ref().filter(|p| p.exists()).cloned();
|
||||||
|
return Some(ResolvedModelPaths {
|
||||||
|
model_path: entry.local_path.clone(),
|
||||||
|
context_limit: ctx,
|
||||||
|
settings,
|
||||||
|
mmproj_path,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -208,9 +219,33 @@ fn build_openai_messages_json(system: &str, messages: &[Message]) -> String {
|
||||||
|
|
||||||
let mut arr: Vec<Value> = vec![json!({"role": "system", "content": system})];
|
let mut arr: Vec<Value> = vec![json!({"role": "system", "content": system})];
|
||||||
arr.extend(format_messages(messages, &ImageFormat::OpenAi));
|
arr.extend(format_messages(messages, &ImageFormat::OpenAi));
|
||||||
|
strip_image_parts_from_messages(&mut arr);
|
||||||
serde_json::to_string(&arr).unwrap_or_else(|_| "[]".to_string())
|
serde_json::to_string(&arr).unwrap_or_else(|_| "[]".to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Remove `image_url` content parts from OpenAI-format messages JSON, replacing
|
||||||
|
/// each with a text note. This prevents an FFI crash in llama.cpp which does not
|
||||||
|
/// accept `image_url` content-part types.
|
||||||
|
fn strip_image_parts_from_messages(messages: &mut [Value]) {
|
||||||
|
let mut stripped = false;
|
||||||
|
for msg in messages.iter_mut() {
|
||||||
|
if let Some(content) = msg.get_mut("content").and_then(|c| c.as_array_mut()) {
|
||||||
|
for part in content.iter_mut() {
|
||||||
|
if part.get("type").and_then(|t| t.as_str()) == Some("image_url") {
|
||||||
|
*part = json!({
|
||||||
|
"type": "text",
|
||||||
|
"text": "[Image attached — image input is not supported with the currently selected model]"
|
||||||
|
});
|
||||||
|
stripped = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if stripped {
|
||||||
|
tracing::warn!("Stripped image content parts from messages — vision encoder not available for this model");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Convert a message into plain text for the emulator path's chat history.
|
/// Convert a message into plain text for the emulator path's chat history.
|
||||||
///
|
///
|
||||||
/// This is the emulator-path counterpart of [`format_messages`] used by the native
|
/// This is the emulator-path counterpart of [`format_messages`] used by the native
|
||||||
|
|
@ -269,6 +304,12 @@ fn extract_text_content(msg: &Message) -> String {
|
||||||
parts.push(format!("Command error: {}", e));
|
parts.push(format!("Command error: {}", e));
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
MessageContent::Image(_) => {
|
||||||
|
parts.push(
|
||||||
|
"[Image attached — image input is not supported with the currently selected model]"
|
||||||
|
.to_string(),
|
||||||
|
);
|
||||||
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -331,8 +372,9 @@ impl LocalInferenceProvider {
|
||||||
model_id: &str,
|
model_id: &str,
|
||||||
settings: &crate::providers::local_inference::local_model_registry::ModelSettings,
|
settings: &crate::providers::local_inference::local_model_registry::ModelSettings,
|
||||||
) -> Result<LoadedModel, ProviderError> {
|
) -> Result<LoadedModel, ProviderError> {
|
||||||
let (model_path, _context_limit, _) = resolve_model_path(model_id)
|
let resolved = resolve_model_path(model_id)
|
||||||
.ok_or_else(|| ProviderError::ExecutionError(format!("Unknown model: {}", model_id)))?;
|
.ok_or_else(|| ProviderError::ExecutionError(format!("Unknown model: {}", model_id)))?;
|
||||||
|
let model_path = resolved.model_path;
|
||||||
|
|
||||||
if !model_path.exists() {
|
if !model_path.exists() {
|
||||||
return Err(ProviderError::ExecutionError(format!(
|
return Err(ProviderError::ExecutionError(format!(
|
||||||
|
|
@ -368,9 +410,49 @@ impl LocalInferenceProvider {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let mtmd_ctx = Self::init_mtmd_context(&model, &resolved.mmproj_path, settings);
|
||||||
|
|
||||||
tracing::info!(model_id = model_id, "Model loaded successfully");
|
tracing::info!(model_id = model_id, "Model loaded successfully");
|
||||||
|
|
||||||
Ok(LoadedModel { model, template })
|
Ok(LoadedModel {
|
||||||
|
model,
|
||||||
|
template,
|
||||||
|
mtmd_ctx,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn init_mtmd_context(
|
||||||
|
model: &LlamaModel,
|
||||||
|
mmproj_path: &Option<PathBuf>,
|
||||||
|
settings: &crate::providers::local_inference::local_model_registry::ModelSettings,
|
||||||
|
) -> Option<llama_cpp_2::mtmd::MtmdContext> {
|
||||||
|
use llama_cpp_2::mtmd::{MtmdContext, MtmdContextParams};
|
||||||
|
|
||||||
|
let mmproj_path = mmproj_path.as_ref().filter(|p| p.exists())?;
|
||||||
|
|
||||||
|
let params = MtmdContextParams {
|
||||||
|
use_gpu: true,
|
||||||
|
n_threads: settings
|
||||||
|
.n_threads
|
||||||
|
.unwrap_or_else(|| MtmdContextParams::default().n_threads),
|
||||||
|
..MtmdContextParams::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
match MtmdContext::init_from_file(mmproj_path.to_str().unwrap_or_default(), model, ¶ms)
|
||||||
|
{
|
||||||
|
Ok(ctx) => {
|
||||||
|
tracing::info!(
|
||||||
|
vision = ctx.support_vision(),
|
||||||
|
audio = ctx.support_audio(),
|
||||||
|
"Multimodal context initialized"
|
||||||
|
);
|
||||||
|
Some(ctx)
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(error = %e, "Failed to init multimodal context");
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -453,13 +535,11 @@ impl Provider for LocalInferenceProvider {
|
||||||
messages: &[Message],
|
messages: &[Message],
|
||||||
tools: &[Tool],
|
tools: &[Tool],
|
||||||
) -> Result<MessageStream, ProviderError> {
|
) -> Result<MessageStream, ProviderError> {
|
||||||
let (_model_path, model_context_limit, model_settings) =
|
let resolved = resolve_model_path(&model_config.model_name).ok_or_else(|| {
|
||||||
resolve_model_path(&model_config.model_name).ok_or_else(|| {
|
ProviderError::ExecutionError(format!("Model not found: {}", model_config.model_name))
|
||||||
ProviderError::ExecutionError(format!(
|
|
||||||
"Model not found: {}",
|
|
||||||
model_config.model_name
|
|
||||||
))
|
|
||||||
})?;
|
})?;
|
||||||
|
let model_context_limit = resolved.context_limit;
|
||||||
|
let model_settings = resolved.settings;
|
||||||
|
|
||||||
// Ensure model is loaded — unload any other models first to free memory.
|
// Ensure model is loaded — unload any other models first to free memory.
|
||||||
{
|
{
|
||||||
|
|
@ -503,6 +583,18 @@ impl Provider for LocalInferenceProvider {
|
||||||
system.to_string()
|
system.to_string()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Extract images for vision-capable models, replacing them with markers.
|
||||||
|
// For non-vision models, leave messages unchanged (existing strip logic handles them).
|
||||||
|
let has_vision = resolved.mmproj_path.is_some();
|
||||||
|
let marker = llama_cpp_2::mtmd::mtmd_default_marker();
|
||||||
|
let (images, vision_messages): (Vec<ExtractedImage>, Option<Vec<Message>>) = if has_vision {
|
||||||
|
let (imgs, msgs) = multimodal::extract_images_from_messages(messages, marker);
|
||||||
|
(imgs, Some(msgs))
|
||||||
|
} else {
|
||||||
|
(Vec::new(), None)
|
||||||
|
};
|
||||||
|
let effective_messages: &[Message] = vision_messages.as_deref().unwrap_or(messages);
|
||||||
|
|
||||||
// Build chat messages for the template
|
// Build chat messages for the template
|
||||||
let mut chat_messages =
|
let mut chat_messages =
|
||||||
vec![
|
vec![
|
||||||
|
|
@ -529,7 +621,7 @@ impl Provider for LocalInferenceProvider {
|
||||||
})?];
|
})?];
|
||||||
}
|
}
|
||||||
|
|
||||||
for msg in messages {
|
for msg in effective_messages {
|
||||||
let role = match msg.role {
|
let role = match msg.role {
|
||||||
Role::User => "user",
|
Role::User => "user",
|
||||||
Role::Assistant => "assistant",
|
Role::Assistant => "assistant",
|
||||||
|
|
@ -553,7 +645,10 @@ impl Provider for LocalInferenceProvider {
|
||||||
};
|
};
|
||||||
|
|
||||||
let oai_messages_json = if model_settings.use_jinja || native_tool_calling {
|
let oai_messages_json = if model_settings.use_jinja || native_tool_calling {
|
||||||
Some(build_openai_messages_json(&system_prompt, messages))
|
Some(build_openai_messages_json(
|
||||||
|
&system_prompt,
|
||||||
|
effective_messages,
|
||||||
|
))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
@ -563,6 +658,7 @@ impl Provider for LocalInferenceProvider {
|
||||||
let model_name = model_config.model_name.clone();
|
let model_name = model_config.model_name.clone();
|
||||||
let context_limit = model_context_limit;
|
let context_limit = model_context_limit;
|
||||||
let settings = model_settings;
|
let settings = model_settings;
|
||||||
|
let mmproj_path = resolved.mmproj_path.clone();
|
||||||
|
|
||||||
let log_payload = serde_json::json!({
|
let log_payload = serde_json::json!({
|
||||||
"system": &system_prompt,
|
"system": &system_prompt,
|
||||||
|
|
@ -604,8 +700,8 @@ impl Provider for LocalInferenceProvider {
|
||||||
}};
|
}};
|
||||||
}
|
}
|
||||||
|
|
||||||
let model_guard = model_arc.blocking_lock();
|
let mut model_guard = model_arc.blocking_lock();
|
||||||
let loaded = match model_guard.as_ref() {
|
let loaded = match model_guard.as_mut() {
|
||||||
Some(l) => l,
|
Some(l) => l,
|
||||||
None => {
|
None => {
|
||||||
send_err!(ProviderError::ExecutionError(
|
send_err!(ProviderError::ExecutionError(
|
||||||
|
|
@ -614,6 +710,16 @@ impl Provider for LocalInferenceProvider {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Lazily initialize the multimodal context if the vision encoder
|
||||||
|
// was downloaded after the model was loaded.
|
||||||
|
if !images.is_empty() && loaded.mtmd_ctx.is_none() {
|
||||||
|
loaded.mtmd_ctx = LocalInferenceProvider::init_mtmd_context(
|
||||||
|
&loaded.model,
|
||||||
|
&mmproj_path,
|
||||||
|
&settings,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
let message_id = Uuid::new_v4().to_string();
|
let message_id = Uuid::new_v4().to_string();
|
||||||
|
|
||||||
let mut gen_ctx = GenerationContext {
|
let mut gen_ctx = GenerationContext {
|
||||||
|
|
@ -626,6 +732,7 @@ impl Provider for LocalInferenceProvider {
|
||||||
message_id: &message_id,
|
message_id: &message_id,
|
||||||
tx: &tx,
|
tx: &tx,
|
||||||
log: &mut log,
|
log: &mut log,
|
||||||
|
images: &images,
|
||||||
};
|
};
|
||||||
|
|
||||||
let result = if use_emulator {
|
let result = if use_emulator {
|
||||||
|
|
|
||||||
|
|
@ -29,8 +29,8 @@ use std::borrow::Cow;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use super::inference_engine::{
|
use super::inference_engine::{
|
||||||
create_and_prefill_context, generation_loop, validate_and_compute_context, GenerationContext,
|
create_and_prefill_context, create_and_prefill_multimodal, generation_loop,
|
||||||
TokenAction,
|
validate_and_compute_context, GenerationContext, TokenAction,
|
||||||
};
|
};
|
||||||
use super::{finalize_usage, StreamSender, CODE_EXECUTION_TOOL, SHELL_TOOL};
|
use super::{finalize_usage, StreamSender, CODE_EXECUTION_TOOL, SHELL_TOOL};
|
||||||
|
|
||||||
|
|
@ -370,26 +370,32 @@ pub(super) fn generate_with_emulated_tools(
|
||||||
ProviderError::ExecutionError(format!("Failed to apply chat template: {}", e))
|
ProviderError::ExecutionError(format!("Failed to apply chat template: {}", e))
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
let (mut llama_ctx, prompt_token_count, effective_ctx) = if !ctx.images.is_empty() {
|
||||||
|
create_and_prefill_multimodal(
|
||||||
|
ctx.loaded,
|
||||||
|
ctx.runtime,
|
||||||
|
&prompt,
|
||||||
|
ctx.images,
|
||||||
|
ctx.context_limit,
|
||||||
|
ctx.settings,
|
||||||
|
)?
|
||||||
|
} else {
|
||||||
let tokens = ctx
|
let tokens = ctx
|
||||||
.loaded
|
.loaded
|
||||||
.model
|
.model
|
||||||
.str_to_token(&prompt, AddBos::Never)
|
.str_to_token(&prompt, AddBos::Never)
|
||||||
.map_err(|e| ProviderError::ExecutionError(e.to_string()))?;
|
.map_err(|e| ProviderError::ExecutionError(e.to_string()))?;
|
||||||
|
let (ptc, ectx) = validate_and_compute_context(
|
||||||
let (prompt_token_count, effective_ctx) = validate_and_compute_context(
|
|
||||||
ctx.loaded,
|
ctx.loaded,
|
||||||
ctx.runtime,
|
ctx.runtime,
|
||||||
tokens.len(),
|
tokens.len(),
|
||||||
ctx.context_limit,
|
ctx.context_limit,
|
||||||
ctx.settings,
|
ctx.settings,
|
||||||
)?;
|
)?;
|
||||||
let mut llama_ctx = create_and_prefill_context(
|
let lctx =
|
||||||
ctx.loaded,
|
create_and_prefill_context(ctx.loaded, ctx.runtime, &tokens, ectx, ctx.settings)?;
|
||||||
ctx.runtime,
|
(lctx, ptc, ectx)
|
||||||
&tokens,
|
};
|
||||||
effective_ctx,
|
|
||||||
ctx.settings,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let message_id = ctx.message_id;
|
let message_id = ctx.message_id;
|
||||||
let tx = ctx.tx;
|
let tx = ctx.tx;
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,11 @@
|
||||||
use crate::providers::errors::ProviderError;
|
use crate::providers::errors::ProviderError;
|
||||||
use crate::providers::local_inference::local_model_registry::ModelSettings;
|
use crate::providers::local_inference::local_model_registry::ModelSettings;
|
||||||
|
use crate::providers::local_inference::multimodal::ExtractedImage;
|
||||||
use crate::providers::utils::RequestLog;
|
use crate::providers::utils::RequestLog;
|
||||||
use llama_cpp_2::context::params::LlamaContextParams;
|
use llama_cpp_2::context::params::LlamaContextParams;
|
||||||
use llama_cpp_2::llama_batch::LlamaBatch;
|
use llama_cpp_2::llama_batch::LlamaBatch;
|
||||||
use llama_cpp_2::model::{LlamaChatMessage, LlamaChatTemplate, LlamaModel};
|
use llama_cpp_2::model::{LlamaChatMessage, LlamaChatTemplate, LlamaModel};
|
||||||
|
use llama_cpp_2::mtmd::{MtmdBitmap, MtmdContext, MtmdInputText};
|
||||||
use llama_cpp_2::sampling::LlamaSampler;
|
use llama_cpp_2::sampling::LlamaSampler;
|
||||||
use std::num::NonZeroU32;
|
use std::num::NonZeroU32;
|
||||||
|
|
||||||
|
|
@ -19,11 +21,14 @@ pub(super) struct GenerationContext<'a> {
|
||||||
pub message_id: &'a str,
|
pub message_id: &'a str,
|
||||||
pub tx: &'a StreamSender,
|
pub tx: &'a StreamSender,
|
||||||
pub log: &'a mut RequestLog,
|
pub log: &'a mut RequestLog,
|
||||||
|
pub images: &'a [ExtractedImage],
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(super) struct LoadedModel {
|
pub(super) struct LoadedModel {
|
||||||
pub model: LlamaModel,
|
pub model: LlamaModel,
|
||||||
pub template: LlamaChatTemplate,
|
pub template: LlamaChatTemplate,
|
||||||
|
/// Multimodal context for vision models. None for text-only models.
|
||||||
|
pub mtmd_ctx: Option<MtmdContext>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Estimate the maximum context length that can fit in available accelerator/CPU
|
/// Estimate the maximum context length that can fit in available accelerator/CPU
|
||||||
|
|
@ -33,11 +38,13 @@ pub(super) struct LoadedModel {
|
||||||
pub(super) fn estimate_max_context_for_memory(
|
pub(super) fn estimate_max_context_for_memory(
|
||||||
model: &LlamaModel,
|
model: &LlamaModel,
|
||||||
runtime: &InferenceRuntime,
|
runtime: &InferenceRuntime,
|
||||||
|
mmproj_overhead_bytes: u64,
|
||||||
) -> Option<usize> {
|
) -> Option<usize> {
|
||||||
let available = super::available_inference_memory_bytes(runtime);
|
let raw_available = super::available_inference_memory_bytes(runtime);
|
||||||
if available == 0 {
|
if raw_available == 0 {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
let available = raw_available.saturating_sub(mmproj_overhead_bytes);
|
||||||
|
|
||||||
// Reserve memory for computation scratch buffers (attention, etc.) and other overhead.
|
// Reserve memory for computation scratch buffers (attention, etc.) and other overhead.
|
||||||
// The compute buffer can be 40-50% of the KV cache size for large models, so we
|
// The compute buffer can be 40-50% of the KV cache size for large models, so we
|
||||||
|
|
@ -209,7 +216,12 @@ pub(super) fn validate_and_compute_context(
|
||||||
settings: &crate::providers::local_inference::local_model_registry::ModelSettings,
|
settings: &crate::providers::local_inference::local_model_registry::ModelSettings,
|
||||||
) -> Result<(usize, usize), ProviderError> {
|
) -> Result<(usize, usize), ProviderError> {
|
||||||
let n_ctx_train = loaded.model.n_ctx_train() as usize;
|
let n_ctx_train = loaded.model.n_ctx_train() as usize;
|
||||||
let memory_max_ctx = estimate_max_context_for_memory(&loaded.model, runtime);
|
let mmproj_overhead = if loaded.mtmd_ctx.is_some() {
|
||||||
|
settings.mmproj_size_bytes
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
let memory_max_ctx = estimate_max_context_for_memory(&loaded.model, runtime, mmproj_overhead);
|
||||||
let effective_ctx = effective_context_size(
|
let effective_ctx = effective_context_size(
|
||||||
prompt_token_count,
|
prompt_token_count,
|
||||||
settings,
|
settings,
|
||||||
|
|
@ -261,6 +273,80 @@ pub(super) fn create_and_prefill_context<'model>(
|
||||||
Ok(ctx)
|
Ok(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Tokenize text + images via mtmd and prefill the context.
|
||||||
|
///
|
||||||
|
/// Returns the llama context, the number of prompt tokens consumed,
|
||||||
|
/// and the effective context size.
|
||||||
|
pub(super) fn create_and_prefill_multimodal<'model>(
|
||||||
|
loaded: &'model LoadedModel,
|
||||||
|
runtime: &InferenceRuntime,
|
||||||
|
prompt_text: &str,
|
||||||
|
images: &[ExtractedImage],
|
||||||
|
context_limit: usize,
|
||||||
|
settings: &ModelSettings,
|
||||||
|
) -> Result<(llama_cpp_2::context::LlamaContext<'model>, usize, usize), ProviderError> {
|
||||||
|
let mtmd_ctx = loaded.mtmd_ctx.as_ref().ok_or_else(|| {
|
||||||
|
ProviderError::ExecutionError(
|
||||||
|
"This model does not have vision support. Download the vision encoder from \
|
||||||
|
Settings > Local Inference, or use a text-only message."
|
||||||
|
.to_string(),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let bitmaps: Vec<MtmdBitmap> = images
|
||||||
|
.iter()
|
||||||
|
.map(|img| {
|
||||||
|
MtmdBitmap::from_buffer(mtmd_ctx, &img.bytes)
|
||||||
|
.map_err(|e| ProviderError::ExecutionError(format!("Failed to decode image: {e}")))
|
||||||
|
})
|
||||||
|
.collect::<Result<_, _>>()?;
|
||||||
|
|
||||||
|
let bitmap_refs: Vec<&MtmdBitmap> = bitmaps.iter().collect();
|
||||||
|
|
||||||
|
let input_text = MtmdInputText {
|
||||||
|
text: prompt_text.to_string(),
|
||||||
|
add_special: true,
|
||||||
|
parse_special: true,
|
||||||
|
};
|
||||||
|
let chunks = mtmd_ctx.tokenize(input_text, &bitmap_refs).map_err(|e| {
|
||||||
|
ProviderError::ExecutionError(format!("Multimodal tokenization failed: {e}"))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let prompt_token_count = chunks.total_tokens();
|
||||||
|
|
||||||
|
let n_ctx_train = loaded.model.n_ctx_train() as usize;
|
||||||
|
let mmproj_overhead = settings.mmproj_size_bytes;
|
||||||
|
let memory_max_ctx = estimate_max_context_for_memory(&loaded.model, runtime, mmproj_overhead);
|
||||||
|
let effective_ctx = effective_context_size(
|
||||||
|
prompt_token_count,
|
||||||
|
settings,
|
||||||
|
context_limit,
|
||||||
|
n_ctx_train,
|
||||||
|
memory_max_ctx,
|
||||||
|
);
|
||||||
|
|
||||||
|
let min_generation_headroom = 512;
|
||||||
|
if prompt_token_count + min_generation_headroom > effective_ctx {
|
||||||
|
return Err(ProviderError::ContextLengthExceeded(format!(
|
||||||
|
"Multimodal prompt ({prompt_token_count} tokens including images) exceeds \
|
||||||
|
context limit ({effective_ctx} tokens)",
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let ctx_params = build_context_params(effective_ctx as u32, settings);
|
||||||
|
let llama_ctx = loaded
|
||||||
|
.model
|
||||||
|
.new_context(runtime.backend(), ctx_params)
|
||||||
|
.map_err(|e| ProviderError::ExecutionError(format!("Failed to create context: {e}")))?;
|
||||||
|
|
||||||
|
let n_batch = llama_ctx.n_batch() as i32;
|
||||||
|
let _n_past = chunks
|
||||||
|
.eval_chunks(mtmd_ctx, &llama_ctx, 0, 0, n_batch, true)
|
||||||
|
.map_err(|e| ProviderError::ExecutionError(format!("Multimodal eval failed: {e}")))?;
|
||||||
|
|
||||||
|
Ok((llama_ctx, prompt_token_count, effective_ctx))
|
||||||
|
}
|
||||||
|
|
||||||
/// Action to take after processing a generated token piece.
|
/// Action to take after processing a generated token piece.
|
||||||
pub(super) enum TokenAction {
|
pub(super) enum TokenAction {
|
||||||
Continue,
|
Continue,
|
||||||
|
|
|
||||||
|
|
@ -9,8 +9,9 @@ use uuid::Uuid;
|
||||||
|
|
||||||
use super::finalize_usage;
|
use super::finalize_usage;
|
||||||
use super::inference_engine::{
|
use super::inference_engine::{
|
||||||
context_cap, create_and_prefill_context, estimate_max_context_for_memory, generation_loop,
|
context_cap, create_and_prefill_context, create_and_prefill_multimodal,
|
||||||
validate_and_compute_context, GenerationContext, TokenAction,
|
estimate_max_context_for_memory, generation_loop, validate_and_compute_context,
|
||||||
|
GenerationContext, TokenAction,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub(super) fn generate_with_native_tools(
|
pub(super) fn generate_with_native_tools(
|
||||||
|
|
@ -21,7 +22,13 @@ pub(super) fn generate_with_native_tools(
|
||||||
) -> Result<(), ProviderError> {
|
) -> Result<(), ProviderError> {
|
||||||
let min_generation_headroom = 512;
|
let min_generation_headroom = 512;
|
||||||
let n_ctx_train = ctx.loaded.model.n_ctx_train() as usize;
|
let n_ctx_train = ctx.loaded.model.n_ctx_train() as usize;
|
||||||
let memory_max_ctx = estimate_max_context_for_memory(&ctx.loaded.model, ctx.runtime);
|
let mmproj_overhead = if ctx.loaded.mtmd_ctx.is_some() {
|
||||||
|
ctx.settings.mmproj_size_bytes
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
let memory_max_ctx =
|
||||||
|
estimate_max_context_for_memory(&ctx.loaded.model, ctx.runtime, mmproj_overhead);
|
||||||
let cap = context_cap(ctx.settings, ctx.context_limit, n_ctx_train, memory_max_ctx);
|
let cap = context_cap(ctx.settings, ctx.context_limit, n_ctx_train, memory_max_ctx);
|
||||||
let token_budget = cap.saturating_sub(min_generation_headroom);
|
let token_budget = cap.saturating_sub(min_generation_headroom);
|
||||||
|
|
||||||
|
|
@ -61,6 +68,8 @@ pub(super) fn generate_with_native_tools(
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let estimated_image_tokens = ctx.images.len() * ctx.settings.image_token_estimate;
|
||||||
|
|
||||||
let template_result = match apply_template(full_tools_json) {
|
let template_result = match apply_template(full_tools_json) {
|
||||||
Ok(r) => {
|
Ok(r) => {
|
||||||
let token_count = ctx
|
let token_count = ctx
|
||||||
|
|
@ -69,7 +78,7 @@ pub(super) fn generate_with_native_tools(
|
||||||
.str_to_token(&r.prompt, AddBos::Never)
|
.str_to_token(&r.prompt, AddBos::Never)
|
||||||
.map(|t| t.len())
|
.map(|t| t.len())
|
||||||
.unwrap_or(0);
|
.unwrap_or(0);
|
||||||
if token_count > token_budget {
|
if token_count + estimated_image_tokens > token_budget {
|
||||||
apply_template(compact_tools).unwrap_or(r)
|
apply_template(compact_tools).unwrap_or(r)
|
||||||
} else {
|
} else {
|
||||||
r
|
r
|
||||||
|
|
@ -85,26 +94,32 @@ pub(super) fn generate_with_native_tools(
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let (mut llama_ctx, prompt_token_count, effective_ctx) = if !ctx.images.is_empty() {
|
||||||
|
create_and_prefill_multimodal(
|
||||||
|
ctx.loaded,
|
||||||
|
ctx.runtime,
|
||||||
|
&template_result.prompt,
|
||||||
|
ctx.images,
|
||||||
|
ctx.context_limit,
|
||||||
|
ctx.settings,
|
||||||
|
)?
|
||||||
|
} else {
|
||||||
let tokens = ctx
|
let tokens = ctx
|
||||||
.loaded
|
.loaded
|
||||||
.model
|
.model
|
||||||
.str_to_token(&template_result.prompt, AddBos::Never)
|
.str_to_token(&template_result.prompt, AddBos::Never)
|
||||||
.map_err(|e| ProviderError::ExecutionError(e.to_string()))?;
|
.map_err(|e| ProviderError::ExecutionError(e.to_string()))?;
|
||||||
|
let (ptc, ectx) = validate_and_compute_context(
|
||||||
let (prompt_token_count, effective_ctx) = validate_and_compute_context(
|
|
||||||
ctx.loaded,
|
ctx.loaded,
|
||||||
ctx.runtime,
|
ctx.runtime,
|
||||||
tokens.len(),
|
tokens.len(),
|
||||||
ctx.context_limit,
|
ctx.context_limit,
|
||||||
ctx.settings,
|
ctx.settings,
|
||||||
)?;
|
)?;
|
||||||
let mut llama_ctx = create_and_prefill_context(
|
let lctx =
|
||||||
ctx.loaded,
|
create_and_prefill_context(ctx.loaded, ctx.runtime, &tokens, ectx, ctx.settings)?;
|
||||||
ctx.runtime,
|
(lctx, ptc, ectx)
|
||||||
&tokens,
|
};
|
||||||
effective_ctx,
|
|
||||||
ctx.settings,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let message_id = ctx.message_id;
|
let message_id = ctx.message_id;
|
||||||
let tx = ctx.tx;
|
let tx = ctx.tx;
|
||||||
|
|
|
||||||
|
|
@ -62,12 +62,27 @@ pub struct ModelSettings {
|
||||||
pub use_jinja: bool,
|
pub use_jinja: bool,
|
||||||
#[serde(default = "default_true")]
|
#[serde(default = "default_true")]
|
||||||
pub enable_thinking: bool,
|
pub enable_thinking: bool,
|
||||||
|
/// Whether this model architecture supports vision input.
|
||||||
|
/// Derived from the featured model table, not user-configurable.
|
||||||
|
#[serde(default)]
|
||||||
|
pub vision_capable: bool,
|
||||||
|
/// Estimated tokens per image for budget planning before mtmd tokenization.
|
||||||
|
/// The actual count is determined after tokenization via `chunks.total_tokens()`.
|
||||||
|
#[serde(default = "default_image_token_estimate")]
|
||||||
|
pub image_token_estimate: usize,
|
||||||
|
/// Size of the mmproj file in bytes, used for memory accounting.
|
||||||
|
#[serde(default)]
|
||||||
|
pub mmproj_size_bytes: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_true() -> bool {
|
fn default_true() -> bool {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn default_image_token_estimate() -> usize {
|
||||||
|
256
|
||||||
|
}
|
||||||
|
|
||||||
fn default_repeat_penalty() -> f32 {
|
fn default_repeat_penalty() -> f32 {
|
||||||
1.0
|
1.0
|
||||||
}
|
}
|
||||||
|
|
@ -94,41 +109,75 @@ impl Default for ModelSettings {
|
||||||
native_tool_calling: false,
|
native_tool_calling: false,
|
||||||
use_jinja: false,
|
use_jinja: false,
|
||||||
enable_thinking: true,
|
enable_thinking: true,
|
||||||
|
vision_capable: false,
|
||||||
|
image_token_estimate: default_image_token_estimate(),
|
||||||
|
mmproj_size_bytes: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// HuggingFace repo + filename for multimodal projection weights (vision encoder).
|
||||||
|
pub struct MmprojSpec {
|
||||||
|
pub repo: &'static str,
|
||||||
|
pub filename: &'static str,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MmprojSpec {
|
||||||
|
/// Local path for this mmproj, namespaced by repo to avoid collisions
|
||||||
|
/// between different models that use the same filename.
|
||||||
|
pub fn local_path(&self) -> std::path::PathBuf {
|
||||||
|
let repo_name = self.repo.split('/').next_back().unwrap_or(self.repo);
|
||||||
|
Paths::in_data_dir("models")
|
||||||
|
.join(repo_name)
|
||||||
|
.join(self.filename)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct FeaturedModel {
|
pub struct FeaturedModel {
|
||||||
/// HuggingFace spec in "author/repo-GGUF:quantization" format.
|
/// HuggingFace spec in "author/repo-GGUF:quantization" format.
|
||||||
pub spec: &'static str,
|
pub spec: &'static str,
|
||||||
/// Whether this model's GGUF template supports native tool calling via llama.cpp.
|
/// Whether this model's GGUF template supports native tool calling via llama.cpp.
|
||||||
pub native_tool_calling: bool,
|
pub native_tool_calling: bool,
|
||||||
|
/// Multimodal projection weights spec. None for text-only models.
|
||||||
|
pub mmproj: Option<MmprojSpec>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const FEATURED_MODELS: &[FeaturedModel] = &[
|
pub const FEATURED_MODELS: &[FeaturedModel] = &[
|
||||||
FeaturedModel {
|
FeaturedModel {
|
||||||
spec: "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M",
|
spec: "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M",
|
||||||
native_tool_calling: false,
|
native_tool_calling: false,
|
||||||
|
mmproj: None,
|
||||||
},
|
},
|
||||||
FeaturedModel {
|
FeaturedModel {
|
||||||
spec: "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M",
|
spec: "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M",
|
||||||
native_tool_calling: false,
|
native_tool_calling: false,
|
||||||
|
mmproj: None,
|
||||||
},
|
},
|
||||||
FeaturedModel {
|
FeaturedModel {
|
||||||
spec: "bartowski/Hermes-2-Pro-Mistral-7B-GGUF:Q4_K_M",
|
spec: "bartowski/Hermes-2-Pro-Mistral-7B-GGUF:Q4_K_M",
|
||||||
native_tool_calling: false,
|
native_tool_calling: false,
|
||||||
|
mmproj: None,
|
||||||
},
|
},
|
||||||
FeaturedModel {
|
FeaturedModel {
|
||||||
spec: "bartowski/Mistral-Small-24B-Instruct-2501-GGUF:Q4_K_M",
|
spec: "bartowski/Mistral-Small-24B-Instruct-2501-GGUF:Q4_K_M",
|
||||||
native_tool_calling: false,
|
native_tool_calling: false,
|
||||||
|
mmproj: None,
|
||||||
},
|
},
|
||||||
FeaturedModel {
|
FeaturedModel {
|
||||||
spec: "unsloth/gemma-4-E4B-it-GGUF:Q4_K_M",
|
spec: "unsloth/gemma-4-E4B-it-GGUF:Q4_K_M",
|
||||||
native_tool_calling: true,
|
native_tool_calling: true,
|
||||||
|
mmproj: Some(MmprojSpec {
|
||||||
|
repo: "unsloth/gemma-4-E4B-it-GGUF",
|
||||||
|
filename: "mmproj-BF16.gguf",
|
||||||
|
}),
|
||||||
},
|
},
|
||||||
FeaturedModel {
|
FeaturedModel {
|
||||||
spec: "unsloth/gemma-4-26B-A4B-it-GGUF:Q4_K_M",
|
spec: "unsloth/gemma-4-26B-A4B-it-GGUF:Q4_K_M",
|
||||||
native_tool_calling: true,
|
native_tool_calling: true,
|
||||||
|
mmproj: Some(MmprojSpec {
|
||||||
|
repo: "unsloth/gemma-4-26B-A4B-it-GGUF",
|
||||||
|
filename: "mmproj-BF16.gguf",
|
||||||
|
}),
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
|
||||||
|
|
@ -144,10 +193,25 @@ pub fn default_settings_for_model(model_id: &str) -> ModelSettings {
|
||||||
});
|
});
|
||||||
ModelSettings {
|
ModelSettings {
|
||||||
native_tool_calling: featured.is_some_and(|m| m.native_tool_calling),
|
native_tool_calling: featured.is_some_and(|m| m.native_tool_calling),
|
||||||
|
vision_capable: featured.is_some_and(|m| m.mmproj.is_some()),
|
||||||
..ModelSettings::default()
|
..ModelSettings::default()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Look up the `MmprojSpec` for a featured model by its model ID.
|
||||||
|
pub fn featured_mmproj_spec(model_id: &str) -> Option<&'static MmprojSpec> {
|
||||||
|
use super::hf_models::parse_model_spec;
|
||||||
|
let model_repo = model_id.split(':').next().unwrap_or(model_id);
|
||||||
|
FEATURED_MODELS.iter().find_map(|m| {
|
||||||
|
if let Ok((repo_id, _quant)) = parse_model_spec(m.spec) {
|
||||||
|
if repo_id == model_repo {
|
||||||
|
return m.mmproj.as_ref();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
/// Check if a model ID corresponds to a featured model.
|
/// Check if a model ID corresponds to a featured model.
|
||||||
pub fn is_featured_model(model_id: &str) -> bool {
|
pub fn is_featured_model(model_id: &str) -> bool {
|
||||||
use super::hf_models::parse_model_spec;
|
use super::hf_models::parse_model_spec;
|
||||||
|
|
@ -181,9 +245,42 @@ pub struct LocalModelEntry {
|
||||||
pub settings: ModelSettings,
|
pub settings: ModelSettings,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub size_bytes: u64,
|
pub size_bytes: u64,
|
||||||
|
/// Local path to the multimodal projection GGUF (vision encoder).
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub mmproj_path: Option<PathBuf>,
|
||||||
|
/// Download URL for the mmproj file.
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub mmproj_source_url: Option<String>,
|
||||||
|
/// Size of the mmproj file in bytes.
|
||||||
|
#[serde(default)]
|
||||||
|
pub mmproj_size_bytes: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LocalModelEntry {
|
impl LocalModelEntry {
|
||||||
|
/// Populate mmproj metadata and vision settings from the featured model
|
||||||
|
/// table if this model's repo has a known vision encoder.
|
||||||
|
pub fn enrich_with_featured_mmproj(&mut self) {
|
||||||
|
if let Some(mmproj) = featured_mmproj_spec(&self.id) {
|
||||||
|
let path = mmproj.local_path();
|
||||||
|
if self.mmproj_path.as_ref() != Some(&path) {
|
||||||
|
self.mmproj_path = Some(path.clone());
|
||||||
|
self.mmproj_source_url = Some(format!(
|
||||||
|
"https://huggingface.co/{}/resolve/main/{}",
|
||||||
|
mmproj.repo, mmproj.filename
|
||||||
|
));
|
||||||
|
}
|
||||||
|
self.settings.vision_capable = true;
|
||||||
|
if self.mmproj_size_bytes == 0 || self.settings.mmproj_size_bytes == 0 {
|
||||||
|
if let Ok(meta) = std::fs::metadata(&path) {
|
||||||
|
self.mmproj_size_bytes = meta.len();
|
||||||
|
self.settings.mmproj_size_bytes = meta.len();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let defaults = default_settings_for_model(&self.id);
|
||||||
|
self.settings.native_tool_calling = defaults.native_tool_calling;
|
||||||
|
}
|
||||||
|
|
||||||
pub fn is_downloaded(&self) -> bool {
|
pub fn is_downloaded(&self) -> bool {
|
||||||
self.local_path.exists()
|
self.local_path.exists()
|
||||||
}
|
}
|
||||||
|
|
@ -219,6 +316,36 @@ impl LocalModelEntry {
|
||||||
ModelDownloadStatus::NotDownloaded
|
ModelDownloadStatus::NotDownloaded
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn has_vision(&self) -> bool {
|
||||||
|
self.mmproj_path.as_ref().is_some_and(|p| p.exists())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn mmproj_download_status(&self) -> ModelDownloadStatus {
|
||||||
|
if let Some(path) = &self.mmproj_path {
|
||||||
|
if path.exists() {
|
||||||
|
return ModelDownloadStatus::Downloaded;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return ModelDownloadStatus::NotDownloaded;
|
||||||
|
}
|
||||||
|
|
||||||
|
let download_id = format!("{}-mmproj", self.id);
|
||||||
|
let manager = get_download_manager();
|
||||||
|
if let Some(progress) = manager.get_progress(&download_id) {
|
||||||
|
return match progress.status {
|
||||||
|
DownloadStatus::Downloading => ModelDownloadStatus::Downloading {
|
||||||
|
progress_percent: progress.progress_percent,
|
||||||
|
bytes_downloaded: progress.bytes_downloaded,
|
||||||
|
total_bytes: progress.total_bytes,
|
||||||
|
speed_bps: progress.speed_bps.unwrap_or(0),
|
||||||
|
},
|
||||||
|
_ => ModelDownloadStatus::NotDownloaded,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
ModelDownloadStatus::NotDownloaded
|
||||||
|
}
|
||||||
|
|
||||||
pub fn file_size(&self) -> u64 {
|
pub fn file_size(&self) -> u64 {
|
||||||
if self.size_bytes > 0 {
|
if self.size_bytes > 0 {
|
||||||
return self.size_bytes;
|
return self.size_bytes;
|
||||||
|
|
@ -290,8 +417,9 @@ impl LocalModelRegistry {
|
||||||
pub fn sync_with_featured(&mut self, featured_entries: Vec<LocalModelEntry>) {
|
pub fn sync_with_featured(&mut self, featured_entries: Vec<LocalModelEntry>) {
|
||||||
let mut changed = false;
|
let mut changed = false;
|
||||||
|
|
||||||
for entry in featured_entries {
|
for mut entry in featured_entries {
|
||||||
if !self.models.iter().any(|m| m.id == entry.id) {
|
if !self.models.iter().any(|m| m.id == entry.id) {
|
||||||
|
entry.enrich_with_featured_mmproj();
|
||||||
self.models.push(entry);
|
self.models.push(entry);
|
||||||
changed = true;
|
changed = true;
|
||||||
}
|
}
|
||||||
|
|
@ -309,7 +437,8 @@ impl LocalModelRegistry {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_model(&mut self, entry: LocalModelEntry) -> Result<()> {
|
pub fn add_model(&mut self, mut entry: LocalModelEntry) -> Result<()> {
|
||||||
|
entry.enrich_with_featured_mmproj();
|
||||||
if let Some(existing) = self.models.iter_mut().find(|m| m.id == entry.id) {
|
if let Some(existing) = self.models.iter_mut().find(|m| m.id == entry.id) {
|
||||||
*existing = entry;
|
*existing = entry;
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -348,6 +477,10 @@ impl LocalModelRegistry {
|
||||||
pub fn list_models(&self) -> &[LocalModelEntry] {
|
pub fn list_models(&self) -> &[LocalModelEntry] {
|
||||||
&self.models
|
&self.models
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn list_models_mut(&mut self) -> &mut [LocalModelEntry] {
|
||||||
|
&mut self.models
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate a unique ID for a model from its repo_id and quantization.
|
/// Generate a unique ID for a model from its repo_id and quantization.
|
||||||
|
|
|
||||||
336
crates/goose/src/providers/local_inference/multimodal.rs
Normal file
336
crates/goose/src/providers/local_inference/multimodal.rs
Normal file
|
|
@ -0,0 +1,336 @@
|
||||||
|
use base64::prelude::*;
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
use crate::conversation::message::{Message, MessageContent};
|
||||||
|
use crate::providers::errors::ProviderError;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ExtractedImage {
|
||||||
|
pub bytes: Vec<u8>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub struct MultimodalMessages {
|
||||||
|
pub messages_json: String,
|
||||||
|
pub images: Vec<ExtractedImage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Walk the OpenAI-format messages JSON array. For each content part with
|
||||||
|
/// `type: "image_url"`, decode the base64 data URL, store the raw bytes,
|
||||||
|
/// and replace the part with `{"type": "text", "text": "<marker>"}`.
|
||||||
|
///
|
||||||
|
/// Returns the modified JSON string and the extracted images in order.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn extract_images_from_messages_json(
|
||||||
|
messages_json: &str,
|
||||||
|
marker: &str,
|
||||||
|
) -> Result<MultimodalMessages, ProviderError> {
|
||||||
|
let mut messages: Vec<Value> = serde_json::from_str(messages_json).map_err(|e| {
|
||||||
|
ProviderError::ExecutionError(format!("Failed to parse messages JSON: {e}"))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let mut images = Vec::new();
|
||||||
|
|
||||||
|
for msg in messages.iter_mut() {
|
||||||
|
let Some(content) = msg.get_mut("content").and_then(|c| c.as_array_mut()) else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
for part in content.iter_mut() {
|
||||||
|
if part.get("type").and_then(|t| t.as_str()) != Some("image_url") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let url = part
|
||||||
|
.get("image_url")
|
||||||
|
.and_then(|obj| obj.get("url"))
|
||||||
|
.and_then(|u| u.as_str())
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
if url.starts_with("http://") || url.starts_with("https://") {
|
||||||
|
return Err(ProviderError::ExecutionError(
|
||||||
|
"Remote image URLs are not supported with local inference. \
|
||||||
|
Please attach the image directly."
|
||||||
|
.to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let base64_data = url.split_once(',').map_or(url, |(_, data)| data);
|
||||||
|
|
||||||
|
let bytes = BASE64_STANDARD.decode(base64_data).map_err(|e| {
|
||||||
|
ProviderError::ExecutionError(format!("Failed to decode base64 image: {e}"))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
images.push(ExtractedImage { bytes });
|
||||||
|
|
||||||
|
*part = serde_json::json!({
|
||||||
|
"type": "text",
|
||||||
|
"text": marker,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let messages_json = serde_json::to_string(&messages)
|
||||||
|
.map_err(|e| ProviderError::ExecutionError(format!("Failed to serialize messages: {e}")))?;
|
||||||
|
|
||||||
|
Ok(MultimodalMessages {
|
||||||
|
messages_json,
|
||||||
|
images,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Scan messages for `MessageContent::Image` entries. Return the extracted image
|
||||||
|
/// bytes and a new message list with images replaced by text marker placeholders.
|
||||||
|
pub fn extract_images_from_messages(
|
||||||
|
messages: &[Message],
|
||||||
|
marker: &str,
|
||||||
|
) -> (Vec<ExtractedImage>, Vec<Message>) {
|
||||||
|
let mut images = Vec::new();
|
||||||
|
let mut new_messages = Vec::with_capacity(messages.len());
|
||||||
|
|
||||||
|
for msg in messages {
|
||||||
|
let mut new_content = Vec::with_capacity(msg.content.len());
|
||||||
|
for content in &msg.content {
|
||||||
|
match content {
|
||||||
|
MessageContent::Image(img) => {
|
||||||
|
if let Ok(bytes) = BASE64_STANDARD.decode(&img.data) {
|
||||||
|
images.push(ExtractedImage { bytes });
|
||||||
|
new_content.push(MessageContent::text(marker));
|
||||||
|
} else {
|
||||||
|
new_content.push(MessageContent::text(
|
||||||
|
"[Image attached — failed to decode image data]",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
other => new_content.push(other.clone()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
new_messages.push(Message {
|
||||||
|
role: msg.role.clone(),
|
||||||
|
content: new_content,
|
||||||
|
..msg.clone()
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
(images, new_messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
fn make_test_messages_json(parts: Vec<Value>) -> String {
|
||||||
|
serde_json::to_string(&vec![json!({
|
||||||
|
"role": "user",
|
||||||
|
"content": parts,
|
||||||
|
})])
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn tiny_png_base64() -> String {
|
||||||
|
// 1x1 red PNG
|
||||||
|
let bytes: &[u8] = &[
|
||||||
|
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D, 0x49, 0x48,
|
||||||
|
0x44, 0x52, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02, 0x00, 0x00,
|
||||||
|
0x00, 0x90, 0x77, 0x53, 0xDE, 0x00, 0x00, 0x00, 0x0C, 0x49, 0x44, 0x41, 0x54, 0x08,
|
||||||
|
0xD7, 0x63, 0xF8, 0xCF, 0xC0, 0x00, 0x00, 0x00, 0x03, 0x00, 0x01, 0x36, 0x28, 0x19,
|
||||||
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4E, 0x44, 0xAE, 0x42, 0x60, 0x82,
|
||||||
|
];
|
||||||
|
BASE64_STANDARD.encode(bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_extract_images_replaces_image_url_with_marker() {
|
||||||
|
let b64 = tiny_png_base64();
|
||||||
|
let json = make_test_messages_json(vec![json!({
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": format!("data:image/png;base64,{b64}")}
|
||||||
|
})]);
|
||||||
|
|
||||||
|
let result = extract_images_from_messages_json(&json, "<__media__>").unwrap();
|
||||||
|
assert_eq!(result.images.len(), 1);
|
||||||
|
assert!(!result.images[0].bytes.is_empty());
|
||||||
|
|
||||||
|
let parsed: Vec<Value> = serde_json::from_str(&result.messages_json).unwrap();
|
||||||
|
let content = parsed[0]["content"].as_array().unwrap();
|
||||||
|
assert_eq!(content[0]["type"], "text");
|
||||||
|
assert_eq!(content[0]["text"], "<__media__>");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_extract_images_preserves_text_parts() {
|
||||||
|
let json = make_test_messages_json(vec![json!({
|
||||||
|
"type": "text",
|
||||||
|
"text": "Hello world"
|
||||||
|
})]);
|
||||||
|
|
||||||
|
let result = extract_images_from_messages_json(&json, "<__media__>").unwrap();
|
||||||
|
assert!(result.images.is_empty());
|
||||||
|
|
||||||
|
let parsed: Vec<Value> = serde_json::from_str(&result.messages_json).unwrap();
|
||||||
|
let content = parsed[0]["content"].as_array().unwrap();
|
||||||
|
assert_eq!(content[0]["type"], "text");
|
||||||
|
assert_eq!(content[0]["text"], "Hello world");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_extract_images_multiple_images() {
|
||||||
|
let b64 = tiny_png_base64();
|
||||||
|
let json = make_test_messages_json(vec![
|
||||||
|
json!({"type": "image_url", "image_url": {"url": format!("data:image/png;base64,{b64}")}}),
|
||||||
|
json!({"type": "text", "text": "describe both"}),
|
||||||
|
json!({"type": "image_url", "image_url": {"url": format!("data:image/png;base64,{b64}")}}),
|
||||||
|
]);
|
||||||
|
|
||||||
|
let result = extract_images_from_messages_json(&json, "<__media__>").unwrap();
|
||||||
|
assert_eq!(result.images.len(), 2);
|
||||||
|
|
||||||
|
let parsed: Vec<Value> = serde_json::from_str(&result.messages_json).unwrap();
|
||||||
|
let content = parsed[0]["content"].as_array().unwrap();
|
||||||
|
assert_eq!(content[0]["type"], "text");
|
||||||
|
assert_eq!(content[0]["text"], "<__media__>");
|
||||||
|
assert_eq!(content[1]["type"], "text");
|
||||||
|
assert_eq!(content[1]["text"], "describe both");
|
||||||
|
assert_eq!(content[2]["type"], "text");
|
||||||
|
assert_eq!(content[2]["text"], "<__media__>");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_extract_images_no_images() {
|
||||||
|
let json = make_test_messages_json(vec![json!({
|
||||||
|
"type": "text",
|
||||||
|
"text": "just text"
|
||||||
|
})]);
|
||||||
|
|
||||||
|
let result = extract_images_from_messages_json(&json, "<__media__>").unwrap();
|
||||||
|
assert!(result.images.is_empty());
|
||||||
|
// JSON should be equivalent
|
||||||
|
let original: Vec<Value> = serde_json::from_str(&json).unwrap();
|
||||||
|
let result_parsed: Vec<Value> = serde_json::from_str(&result.messages_json).unwrap();
|
||||||
|
assert_eq!(original, result_parsed);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_extract_images_http_url_rejected() {
|
||||||
|
let json = make_test_messages_json(vec![json!({
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": "https://example.com/image.png"}
|
||||||
|
})]);
|
||||||
|
|
||||||
|
let result = extract_images_from_messages_json(&json, "<__media__>");
|
||||||
|
assert!(result.is_err());
|
||||||
|
let err = result.unwrap_err().to_string();
|
||||||
|
assert!(err.contains("Remote image URLs are not supported"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_extract_images_mixed_content() {
|
||||||
|
let b64 = tiny_png_base64();
|
||||||
|
// Two messages: first with text+image, second with just text
|
||||||
|
let json = serde_json::to_string(&vec![
|
||||||
|
json!({
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "What is this?"},
|
||||||
|
{"type": "image_url", "image_url": {"url": format!("data:image/png;base64,{b64}")}},
|
||||||
|
]
|
||||||
|
}),
|
||||||
|
json!({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "It looks like a red pixel."},
|
||||||
|
]
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let result = extract_images_from_messages_json(&json, "<__media__>").unwrap();
|
||||||
|
assert_eq!(result.images.len(), 1);
|
||||||
|
|
||||||
|
let parsed: Vec<Value> = serde_json::from_str(&result.messages_json).unwrap();
|
||||||
|
// First message: text preserved, image replaced
|
||||||
|
let content0 = parsed[0]["content"].as_array().unwrap();
|
||||||
|
assert_eq!(content0[0]["text"], "What is this?");
|
||||||
|
assert_eq!(content0[1]["text"], "<__media__>");
|
||||||
|
// Second message unchanged
|
||||||
|
let content1 = parsed[1]["content"].as_array().unwrap();
|
||||||
|
assert_eq!(content1[0]["text"], "It looks like a red pixel.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Tests for extract_images_from_messages (Message-based) ---
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_messages_extract_replaces_image_with_marker() {
|
||||||
|
let b64 = tiny_png_base64();
|
||||||
|
let messages = vec![Message::user().with_image(b64, "image/png")];
|
||||||
|
|
||||||
|
let (images, new_msgs) = extract_images_from_messages(&messages, "<__media__>");
|
||||||
|
assert_eq!(images.len(), 1);
|
||||||
|
assert!(!images[0].bytes.is_empty());
|
||||||
|
assert_eq!(new_msgs.len(), 1);
|
||||||
|
assert_eq!(new_msgs[0].as_concat_text(), "<__media__>");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_messages_extract_preserves_text() {
|
||||||
|
let messages = vec![Message::user().with_text("Hello world")];
|
||||||
|
|
||||||
|
let (images, new_msgs) = extract_images_from_messages(&messages, "<__media__>");
|
||||||
|
assert!(images.is_empty());
|
||||||
|
assert_eq!(new_msgs[0].as_concat_text(), "Hello world");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_messages_extract_multiple_images() {
|
||||||
|
let b64 = tiny_png_base64();
|
||||||
|
let messages = vec![Message::user()
|
||||||
|
.with_image(b64.clone(), "image/png")
|
||||||
|
.with_text("describe both")
|
||||||
|
.with_image(b64, "image/png")];
|
||||||
|
|
||||||
|
let (images, new_msgs) = extract_images_from_messages(&messages, "<__media__>");
|
||||||
|
assert_eq!(images.len(), 2);
|
||||||
|
assert_eq!(new_msgs[0].content.len(), 3);
|
||||||
|
assert_eq!(
|
||||||
|
new_msgs[0].as_concat_text(),
|
||||||
|
"<__media__>\ndescribe both\n<__media__>"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_messages_extract_no_images() {
|
||||||
|
let messages = vec![Message::user().with_text("just text")];
|
||||||
|
|
||||||
|
let (images, new_msgs) = extract_images_from_messages(&messages, "<__media__>");
|
||||||
|
assert!(images.is_empty());
|
||||||
|
assert_eq!(new_msgs[0].as_concat_text(), "just text");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_messages_extract_invalid_base64() {
|
||||||
|
let messages = vec![Message::user().with_image("not-valid-base64!!!", "image/png")];
|
||||||
|
|
||||||
|
let (images, new_msgs) = extract_images_from_messages(&messages, "<__media__>");
|
||||||
|
assert!(images.is_empty());
|
||||||
|
assert!(new_msgs[0].as_concat_text().contains("failed to decode"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_messages_extract_mixed_content() {
|
||||||
|
let b64 = tiny_png_base64();
|
||||||
|
let messages = vec![
|
||||||
|
Message::user()
|
||||||
|
.with_text("What is this?")
|
||||||
|
.with_image(b64, "image/png"),
|
||||||
|
Message::assistant().with_text("It looks like a red pixel."),
|
||||||
|
];
|
||||||
|
|
||||||
|
let (images, new_msgs) = extract_images_from_messages(&messages, "<__media__>");
|
||||||
|
assert_eq!(images.len(), 1);
|
||||||
|
assert_eq!(new_msgs.len(), 2);
|
||||||
|
assert_eq!(new_msgs[0].as_concat_text(), "What is this?\n<__media__>");
|
||||||
|
assert_eq!(new_msgs[1].as_concat_text(), "It looks like a red pixel.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -9,7 +9,11 @@
|
||||||
//!
|
//!
|
||||||
//! Run with a specific model:
|
//! Run with a specific model:
|
||||||
//! TEST_MODEL="bartowski/Qwen_Qwen3-32B-GGUF:Q4_K_M" cargo test -p goose --test local_inference_integration -- --ignored
|
//! TEST_MODEL="bartowski/Qwen_Qwen3-32B-GGUF:Q4_K_M" cargo test -p goose --test local_inference_integration -- --ignored
|
||||||
|
//!
|
||||||
|
//! Run vision tests (requires a vision-capable model like gemma-4):
|
||||||
|
//! TEST_VISION_MODEL="unsloth/gemma-4-E4B-it-GGUF:Q4_K_M" cargo test -p goose --test local_inference_integration test_local_inference_vision -- --ignored
|
||||||
|
|
||||||
|
use base64::prelude::*;
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
use goose::conversation::message::Message;
|
use goose::conversation::message::Message;
|
||||||
use goose::model::ModelConfig;
|
use goose::model::ModelConfig;
|
||||||
|
|
@ -93,3 +97,130 @@ async fn test_local_inference_large_prompt() {
|
||||||
text.len()
|
text.len()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn vision_test_model() -> Option<String> {
|
||||||
|
std::env::var("TEST_VISION_MODEL").ok()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate a small solid-colour 2x2 red PNG as raw bytes.
|
||||||
|
fn tiny_red_png() -> Vec<u8> {
|
||||||
|
vec![
|
||||||
|
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature
|
||||||
|
0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52, // IHDR chunk
|
||||||
|
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, // 1x1
|
||||||
|
0x08, 0x02, 0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE, // RGB, 8-bit
|
||||||
|
0x00, 0x00, 0x00, 0x0C, 0x49, 0x44, 0x41, 0x54, // IDAT chunk
|
||||||
|
0x08, 0xD7, 0x63, 0xF8, 0xCF, 0xC0, 0x00, 0x00, 0x00, 0x03, 0x00, 0x01, 0x36, 0x28, 0x19,
|
||||||
|
0x00, // compressed pixel data
|
||||||
|
0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4E, 0x44, 0xAE, 0x42, 0x60, 0x82, // IEND
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Test that a vision-capable local model can process a message with an embedded image
|
||||||
|
/// and produce a text response without crashing.
|
||||||
|
///
|
||||||
|
/// Requires TEST_VISION_MODEL to be set to a downloaded vision model.
|
||||||
|
/// Example:
|
||||||
|
/// TEST_VISION_MODEL="unsloth/gemma-4-E4B-it-GGUF:Q4_K_M" \
|
||||||
|
/// cargo test -p goose --test local_inference_integration test_local_inference_vision -- --ignored
|
||||||
|
#[tokio::test]
|
||||||
|
#[ignore]
|
||||||
|
async fn test_local_inference_vision_produces_output() {
|
||||||
|
let model_id = match vision_test_model() {
|
||||||
|
Some(id) => id,
|
||||||
|
None => {
|
||||||
|
eprintln!(
|
||||||
|
"Skipping vision test: TEST_VISION_MODEL not set. \
|
||||||
|
Set it to a vision-capable model like unsloth/gemma-4-E4B-it-GGUF:Q4_K_M"
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let model_config = ModelConfig::new(&model_id).expect("valid model config");
|
||||||
|
let provider = create("local", model_config.clone(), Vec::new())
|
||||||
|
.await
|
||||||
|
.expect("provider creation should succeed");
|
||||||
|
|
||||||
|
let image_bytes = tiny_red_png();
|
||||||
|
let image_b64 = BASE64_STANDARD.encode(&image_bytes);
|
||||||
|
|
||||||
|
let system = "You are a helpful assistant. Describe images briefly.";
|
||||||
|
let messages = vec![Message::user()
|
||||||
|
.with_text("What color is this image?")
|
||||||
|
.with_image(image_b64, "image/png")];
|
||||||
|
|
||||||
|
let mut stream = provider
|
||||||
|
.stream(&model_config, "test-vision-session", system, &messages, &[])
|
||||||
|
.await
|
||||||
|
.expect("stream should start for vision input");
|
||||||
|
|
||||||
|
let mut got_text = false;
|
||||||
|
let mut collected_text = String::new();
|
||||||
|
|
||||||
|
while let Some(result) = stream.next().await {
|
||||||
|
let (msg, _usage) = result.expect("stream item should be Ok");
|
||||||
|
if let Some(m) = msg {
|
||||||
|
got_text = true;
|
||||||
|
collected_text.push_str(&m.as_concat_text());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
got_text,
|
||||||
|
"vision stream should produce at least one text message"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
!collected_text.is_empty(),
|
||||||
|
"vision response should contain text"
|
||||||
|
);
|
||||||
|
println!("Vision response: {collected_text}");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Test that sending an image to a text-only model produces a clear error
|
||||||
|
/// rather than crashing.
|
||||||
|
#[tokio::test]
|
||||||
|
#[ignore]
|
||||||
|
async fn test_local_inference_vision_text_only_model_graceful() {
|
||||||
|
let model_config = ModelConfig::new(&test_model()).expect("valid model config");
|
||||||
|
let provider = create("local", model_config.clone(), Vec::new())
|
||||||
|
.await
|
||||||
|
.expect("provider creation should succeed");
|
||||||
|
|
||||||
|
let image_bytes = tiny_red_png();
|
||||||
|
let image_b64 = BASE64_STANDARD.encode(&image_bytes);
|
||||||
|
|
||||||
|
let system = "You are a helpful assistant.";
|
||||||
|
let messages = vec![Message::user()
|
||||||
|
.with_text("What is this?")
|
||||||
|
.with_image(image_b64, "image/png")];
|
||||||
|
|
||||||
|
let mut stream = provider
|
||||||
|
.stream(&model_config, "test-session", system, &messages, &[])
|
||||||
|
.await
|
||||||
|
.expect("stream should start");
|
||||||
|
|
||||||
|
// The stream should either produce a response with the image stripped
|
||||||
|
// (placeholder text) or produce an error — but it must not crash.
|
||||||
|
let mut completed = false;
|
||||||
|
while let Some(result) = stream.next().await {
|
||||||
|
match result {
|
||||||
|
Ok(_) => completed = true,
|
||||||
|
Err(e) => {
|
||||||
|
// An error about missing vision support is acceptable
|
||||||
|
let err_msg = e.to_string();
|
||||||
|
assert!(
|
||||||
|
err_msg.contains("vision") || err_msg.contains("image"),
|
||||||
|
"error should mention vision/image support, got: {err_msg}"
|
||||||
|
);
|
||||||
|
completed = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
completed,
|
||||||
|
"stream should complete without crashing when images sent to text-only model"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -317,6 +317,31 @@ prompt: |
|
||||||
Log results to: {{ workspace_dir }}/phase3_delegation.md
|
Log results to: {{ workspace_dir }}/phase3_delegation.md
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
|
{% if test_phases == "all" or "vision" in test_phases %}
|
||||||
|
## 📷 PHASE 3B: Local Inference Vision Testing
|
||||||
|
|
||||||
|
**Prerequisites**: A vision-capable local model must be downloaded (e.g., gemma-4-E4B).
|
||||||
|
Skip this phase if no local vision model is available.
|
||||||
|
|
||||||
|
### Vision Smoke Test
|
||||||
|
1. Create a small test image:
|
||||||
|
```
|
||||||
|
python3 -c "import struct, zlib; raw=b'\x00\xff\x00\x00'; d=zlib.compress(raw); ihdr=b'\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00'; print('Created test.png')"
|
||||||
|
```
|
||||||
|
Or simply create a 1-pixel PNG test image using available tools.
|
||||||
|
2. Verify the test image file exists and is valid.
|
||||||
|
3. Send a message to the local vision model referencing the test image.
|
||||||
|
4. Verify the model responds with text (not an error or crash).
|
||||||
|
5. Verify the response acknowledges the image content.
|
||||||
|
|
||||||
|
### Vision Error Handling Test
|
||||||
|
1. If a text-only local model is available, send it a message with an image attached.
|
||||||
|
2. Verify it responds gracefully (either with a placeholder message or a clear error),
|
||||||
|
not with a crash or FFI error.
|
||||||
|
|
||||||
|
Log results to: {{ workspace_dir }}/phase3b_vision.md
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
{% if test_phases == "all" or "advanced" in test_phases %}
|
{% if test_phases == "all" or "advanced" in test_phases %}
|
||||||
## 🔬 PHASE 4: Advanced Testing
|
## 🔬 PHASE 4: Advanced Testing
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5724,7 +5724,8 @@
|
||||||
"size_bytes",
|
"size_bytes",
|
||||||
"status",
|
"status",
|
||||||
"recommended",
|
"recommended",
|
||||||
"settings"
|
"settings",
|
||||||
|
"vision_capable"
|
||||||
],
|
],
|
||||||
"properties": {
|
"properties": {
|
||||||
"filename": {
|
"filename": {
|
||||||
|
|
@ -5733,6 +5734,14 @@
|
||||||
"id": {
|
"id": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
|
"mmproj_status": {
|
||||||
|
"allOf": [
|
||||||
|
{
|
||||||
|
"$ref": "#/components/schemas/ModelDownloadStatus"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"nullable": true
|
||||||
|
},
|
||||||
"quantization": {
|
"quantization": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
|
|
@ -5752,6 +5761,9 @@
|
||||||
},
|
},
|
||||||
"status": {
|
"status": {
|
||||||
"$ref": "#/components/schemas/ModelDownloadStatus"
|
"$ref": "#/components/schemas/ModelDownloadStatus"
|
||||||
|
},
|
||||||
|
"vision_capable": {
|
||||||
|
"type": "boolean"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
@ -6497,11 +6509,22 @@
|
||||||
"type": "number",
|
"type": "number",
|
||||||
"format": "float"
|
"format": "float"
|
||||||
},
|
},
|
||||||
|
"image_token_estimate": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Estimated tokens per image for budget planning before mtmd tokenization.\nThe actual count is determined after tokenization via `chunks.total_tokens()`.",
|
||||||
|
"minimum": 0
|
||||||
|
},
|
||||||
"max_output_tokens": {
|
"max_output_tokens": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"nullable": true,
|
"nullable": true,
|
||||||
"minimum": 0
|
"minimum": 0
|
||||||
},
|
},
|
||||||
|
"mmproj_size_bytes": {
|
||||||
|
"type": "integer",
|
||||||
|
"format": "int64",
|
||||||
|
"description": "Size of the mmproj file in bytes, used for memory accounting.",
|
||||||
|
"minimum": 0
|
||||||
|
},
|
||||||
"n_batch": {
|
"n_batch": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"format": "int32",
|
"format": "int32",
|
||||||
|
|
@ -6542,6 +6565,10 @@
|
||||||
},
|
},
|
||||||
"use_mlock": {
|
"use_mlock": {
|
||||||
"type": "boolean"
|
"type": "boolean"
|
||||||
|
},
|
||||||
|
"vision_capable": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "Whether this model architecture supports vision input.\nDerived from the featured model table, not user-configurable."
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -612,12 +612,14 @@ export type LoadedProvider = {
|
||||||
export type LocalModelResponse = {
|
export type LocalModelResponse = {
|
||||||
filename: string;
|
filename: string;
|
||||||
id: string;
|
id: string;
|
||||||
|
mmproj_status?: ModelDownloadStatus | null;
|
||||||
quantization: string;
|
quantization: string;
|
||||||
recommended: boolean;
|
recommended: boolean;
|
||||||
repo_id: string;
|
repo_id: string;
|
||||||
settings: ModelSettings;
|
settings: ModelSettings;
|
||||||
size_bytes: number;
|
size_bytes: number;
|
||||||
status: ModelDownloadStatus;
|
status: ModelDownloadStatus;
|
||||||
|
vision_capable: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -821,7 +823,16 @@ export type ModelSettings = {
|
||||||
enable_thinking?: boolean;
|
enable_thinking?: boolean;
|
||||||
flash_attention?: boolean | null;
|
flash_attention?: boolean | null;
|
||||||
frequency_penalty?: number;
|
frequency_penalty?: number;
|
||||||
|
/**
|
||||||
|
* Estimated tokens per image for budget planning before mtmd tokenization.
|
||||||
|
* The actual count is determined after tokenization via `chunks.total_tokens()`.
|
||||||
|
*/
|
||||||
|
image_token_estimate?: number;
|
||||||
max_output_tokens?: number | null;
|
max_output_tokens?: number | null;
|
||||||
|
/**
|
||||||
|
* Size of the mmproj file in bytes, used for memory accounting.
|
||||||
|
*/
|
||||||
|
mmproj_size_bytes?: number;
|
||||||
n_batch?: number | null;
|
n_batch?: number | null;
|
||||||
n_gpu_layers?: number | null;
|
n_gpu_layers?: number | null;
|
||||||
n_threads?: number | null;
|
n_threads?: number | null;
|
||||||
|
|
@ -832,6 +843,11 @@ export type ModelSettings = {
|
||||||
sampling?: SamplingConfig;
|
sampling?: SamplingConfig;
|
||||||
use_jinja?: boolean;
|
use_jinja?: boolean;
|
||||||
use_mlock?: boolean;
|
use_mlock?: boolean;
|
||||||
|
/**
|
||||||
|
* Whether this model architecture supports vision input.
|
||||||
|
* Derived from the featured model table, not user-configurable.
|
||||||
|
*/
|
||||||
|
vision_capable?: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type ModelTemplate = {
|
export type ModelTemplate = {
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import { useState, useEffect, useCallback, useRef } from 'react';
|
import { useState, useEffect, useCallback, useRef } from 'react';
|
||||||
import { Download, Trash2, X, ChevronDown, ChevronUp, Settings2 } from 'lucide-react';
|
import { Download, Trash2, X, ChevronDown, ChevronUp, Settings2, Eye } from 'lucide-react';
|
||||||
import { Button } from '../../ui/button';
|
import { Button } from '../../ui/button';
|
||||||
import { useModelAndProvider } from '../../ModelAndProviderContext';
|
import { useModelAndProvider } from '../../ModelAndProviderContext';
|
||||||
import { defineMessages, useIntl } from '../../../i18n';
|
import { defineMessages, useIntl } from '../../../i18n';
|
||||||
|
|
@ -83,8 +83,57 @@ const i18n = defineMessages({
|
||||||
id: 'localInferenceSettings.modelSettingsTitle',
|
id: 'localInferenceSettings.modelSettingsTitle',
|
||||||
defaultMessage: 'Model settings',
|
defaultMessage: 'Model settings',
|
||||||
},
|
},
|
||||||
|
vision: {
|
||||||
|
id: 'localInferenceSettings.vision',
|
||||||
|
defaultMessage: 'Vision',
|
||||||
|
},
|
||||||
|
visionEncoderDownloading: {
|
||||||
|
id: 'localInferenceSettings.visionEncoderDownloading',
|
||||||
|
defaultMessage: 'Vision encoder downloading…',
|
||||||
|
},
|
||||||
|
visionEncoderNotDownloaded: {
|
||||||
|
id: 'localInferenceSettings.visionEncoderNotDownloaded',
|
||||||
|
defaultMessage: 'Vision encoder not downloaded',
|
||||||
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const VisionBadge = ({ model, intl }: { model: LocalModelResponse; intl: ReturnType<typeof useIntl> }) => {
|
||||||
|
if (!model.vision_capable) return null;
|
||||||
|
|
||||||
|
const mmproj = model.mmproj_status;
|
||||||
|
const isDownloaded = mmproj?.state === 'Downloaded';
|
||||||
|
const isDownloading = mmproj?.state === 'Downloading';
|
||||||
|
|
||||||
|
if (isDownloaded) {
|
||||||
|
return (
|
||||||
|
<span className="inline-flex items-center gap-1 text-xs text-green-400 bg-green-500/10 px-2 py-0.5 rounded">
|
||||||
|
<Eye className="w-3 h-3" />
|
||||||
|
{intl.formatMessage(i18n.vision)}
|
||||||
|
</span>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isDownloading) {
|
||||||
|
const percent = mmproj && 'progress_percent' in mmproj
|
||||||
|
? Math.round(mmproj.progress_percent)
|
||||||
|
: null;
|
||||||
|
return (
|
||||||
|
<span className="inline-flex items-center gap-1 text-xs text-yellow-400 bg-yellow-500/10 px-2 py-0.5 rounded">
|
||||||
|
<Eye className="w-3 h-3" />
|
||||||
|
{intl.formatMessage(i18n.visionEncoderDownloading)}
|
||||||
|
{percent != null && ` ${percent}%`}
|
||||||
|
</span>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<span className="inline-flex items-center gap-1 text-xs text-text-muted bg-background-subtle px-2 py-0.5 rounded">
|
||||||
|
<Eye className="w-3 h-3" />
|
||||||
|
{intl.formatMessage(i18n.vision)}
|
||||||
|
</span>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
const formatBytes = (bytes: number): string => {
|
const formatBytes = (bytes: number): string => {
|
||||||
if (bytes < 1024) return `${bytes}B`;
|
if (bytes < 1024) return `${bytes}B`;
|
||||||
if (bytes < 1024 * 1024) return `${(bytes / 1024).toFixed(0)}KB`;
|
if (bytes < 1024 * 1024) return `${(bytes / 1024).toFixed(0)}KB`;
|
||||||
|
|
@ -128,6 +177,19 @@ export const LocalInferenceSettings = () => {
|
||||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
// Poll model list while any vision encoder is downloading
|
||||||
|
useEffect(() => {
|
||||||
|
const hasDownloadingMmproj = models.some(
|
||||||
|
(m) => m.vision_capable && m.mmproj_status?.state === 'Downloading'
|
||||||
|
);
|
||||||
|
if (!hasDownloadingMmproj) return;
|
||||||
|
|
||||||
|
const interval = setInterval(() => {
|
||||||
|
loadModels();
|
||||||
|
}, 2000);
|
||||||
|
return () => clearInterval(interval);
|
||||||
|
}, [models, loadModels]);
|
||||||
|
|
||||||
const selectModel = async (modelId: string) => {
|
const selectModel = async (modelId: string) => {
|
||||||
try {
|
try {
|
||||||
await setConfigProvider({
|
await setConfigProvider({
|
||||||
|
|
@ -365,6 +427,7 @@ export const LocalInferenceSettings = () => {
|
||||||
{intl.formatMessage(i18n.recommended)}
|
{intl.formatMessage(i18n.recommended)}
|
||||||
</span>
|
</span>
|
||||||
)}
|
)}
|
||||||
|
<VisionBadge model={model} intl={intl} />
|
||||||
</div>
|
</div>
|
||||||
<div className="flex items-center gap-1">
|
<div className="flex items-center gap-1">
|
||||||
<Button
|
<Button
|
||||||
|
|
@ -414,6 +477,7 @@ export const LocalInferenceSettings = () => {
|
||||||
{intl.formatMessage(i18n.recommended)}
|
{intl.formatMessage(i18n.recommended)}
|
||||||
</span>
|
</span>
|
||||||
)}
|
)}
|
||||||
|
<VisionBadge model={model} intl={intl} />
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<Button
|
<Button
|
||||||
|
|
|
||||||
|
|
@ -1838,6 +1838,15 @@
|
||||||
"localInferenceSettings.title": {
|
"localInferenceSettings.title": {
|
||||||
"defaultMessage": "Local Inference Models"
|
"defaultMessage": "Local Inference Models"
|
||||||
},
|
},
|
||||||
|
"localInferenceSettings.vision": {
|
||||||
|
"defaultMessage": "Vision"
|
||||||
|
},
|
||||||
|
"localInferenceSettings.visionEncoderDownloading": {
|
||||||
|
"defaultMessage": "Vision encoder downloading\u2026"
|
||||||
|
},
|
||||||
|
"localInferenceSettings.visionEncoderNotDownloaded": {
|
||||||
|
"defaultMessage": "Vision encoder not downloaded"
|
||||||
|
},
|
||||||
"localModelManager.active": {
|
"localModelManager.active": {
|
||||||
"defaultMessage": "Active"
|
"defaultMessage": "Active"
|
||||||
},
|
},
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue