From de317d5445ce00e8acda99a2597dc6bf08cf7f9b Mon Sep 17 00:00:00 2001 From: jh-block Date: Mon, 13 Apr 2026 10:17:04 +0200 Subject: [PATCH] Add vision/image support for local inference models (#8442) Signed-off-by: jh-block --- crates/goose-cli/src/cli.rs | 3 + .../src/routes/local_inference.rs | 157 +++++++- crates/goose/Cargo.toml | 4 +- crates/goose/src/providers/local_inference.rs | 151 ++++++-- .../inference_emulated_tools.rs | 50 +-- .../local_inference/inference_engine.rs | 92 ++++- .../local_inference/inference_native_tools.rs | 63 ++-- .../local_inference/local_model_registry.rs | 137 ++++++- .../providers/local_inference/multimodal.rs | 336 ++++++++++++++++++ .../tests/local_inference_integration.rs | 131 +++++++ goose-self-test.yaml | 25 ++ ui/desktop/openapi.json | 29 +- ui/desktop/src/api/types.gen.ts | 16 + .../localInference/LocalInferenceSettings.tsx | 66 +++- ui/desktop/src/i18n/messages/en.json | 9 + 15 files changed, 1181 insertions(+), 88 deletions(-) create mode 100644 crates/goose/src/providers/local_inference/multimodal.rs diff --git a/crates/goose-cli/src/cli.rs b/crates/goose-cli/src/cli.rs index d43a3425f6..c0ac195213 100644 --- a/crates/goose-cli/src/cli.rs +++ b/crates/goose-cli/src/cli.rs @@ -1604,6 +1604,9 @@ async fn handle_local_models_command(command: LocalModelsCommand) -> Result<()> source_url: file.download_url.clone(), settings: Default::default(), size_bytes: file.size_bytes, + mmproj_path: None, + mmproj_source_url: None, + mmproj_size_bytes: 0, }; { diff --git a/crates/goose-server/src/routes/local_inference.rs b/crates/goose-server/src/routes/local_inference.rs index ddad43ff98..851e0795aa 100644 --- a/crates/goose-server/src/routes/local_inference.rs +++ b/crates/goose-server/src/routes/local_inference.rs @@ -1,3 +1,5 @@ +use std::path::PathBuf; + use crate::routes::errors::ErrorResponse; use crate::state::AppState; use axum::{ @@ -13,9 +15,9 @@ use goose::providers::local_inference::{ available_inference_memory_bytes, hf_models::{resolve_model_spec, HfGgufFile}, local_model_registry::{ - default_settings_for_model, get_registry, is_featured_model, model_id_from_repo, - LocalModelEntry, ModelDownloadStatus as RegistryDownloadStatus, ModelSettings, - FEATURED_MODELS, + default_settings_for_model, featured_mmproj_spec, get_registry, is_featured_model, + model_id_from_repo, LocalModelEntry, ModelDownloadStatus as RegistryDownloadStatus, + ModelSettings, FEATURED_MODELS, }, recommend_local_model, }; @@ -47,10 +49,14 @@ pub struct LocalModelResponse { pub status: ModelDownloadStatus, pub recommended: bool, pub settings: ModelSettings, + pub vision_capable: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub mmproj_status: Option, } async fn ensure_featured_models_in_registry() -> Result<(), ErrorResponse> { let mut entries_to_add = Vec::new(); + let mut mmproj_downloads_needed: Vec<(String, String, PathBuf)> = Vec::new(); for featured in FEATURED_MODELS { let (repo_id, quantization) = match hf_models::parse_model_spec(featured.spec) { @@ -64,8 +70,27 @@ async fn ensure_featured_models_in_registry() -> Result<(), ErrorResponse> { let registry = get_registry() .lock() .map_err(|_| ErrorResponse::internal("Failed to acquire registry lock"))?; - if registry.has_model(&model_id) { - continue; + if let Some(existing) = registry.get_model(&model_id) { + let needs_backfill = existing.mmproj_path.is_none() && featured.mmproj.is_some(); + let needs_download = existing.is_downloaded() + && featured.mmproj.is_some() + && 
!existing.mmproj_path.as_ref().is_some_and(|p| p.exists()); + + if needs_download { + if let Some(mmproj) = featured.mmproj.as_ref() { + let path = mmproj.local_path(); + let url = format!( + "https://huggingface.co/{}/resolve/main/{}", + mmproj.repo, mmproj.filename + ); + mmproj_downloads_needed.push((model_id.clone(), url, path)); + } + } + + if !needs_backfill { + continue; + } + // Fall through to build the entry for sync_with_featured backfill } } @@ -91,6 +116,8 @@ async fn ensure_featured_models_in_registry() -> Result<(), ErrorResponse> { let local_path = Paths::in_data_dir("models").join(&hf_file.filename); + // enrich_with_featured_mmproj is called by sync_with_featured/add_model, + // so we don't need to populate mmproj fields here. entries_to_add.push(LocalModelEntry { id: model_id.clone(), repo_id, @@ -100,14 +127,58 @@ async fn ensure_featured_models_in_registry() -> Result<(), ErrorResponse> { source_url: hf_file.download_url, settings: default_settings_for_model(&model_id), size_bytes: hf_file.size_bytes, + mmproj_path: None, + mmproj_source_url: None, + mmproj_size_bytes: 0, }); } - if !entries_to_add.is_empty() { + { let mut registry = get_registry() .lock() .map_err(|_| ErrorResponse::internal("Failed to acquire registry lock"))?; - registry.sync_with_featured(entries_to_add); + + if !entries_to_add.is_empty() { + registry.sync_with_featured(entries_to_add); + } + + // Backfill mmproj data for all registry models and collect any + // needed mmproj downloads for models already on disk. + for model in registry.list_models_mut() { + model.enrich_with_featured_mmproj(); + if model.is_downloaded() { + if let Some(mmproj) = featured_mmproj_spec(&model.id) { + let path = mmproj.local_path(); + if !path.exists() { + let url = format!( + "https://huggingface.co/{}/resolve/main/{}", + mmproj.repo, mmproj.filename + ); + mmproj_downloads_needed.push((model.id.clone(), url, path)); + } + } + } + } + let _ = registry.save(); + } + + // Auto-download mmproj files for models that are already downloaded. + // Deduplicate by path since multiple quants share one mmproj file. 
+ let dm = get_download_manager(); + let mut started_paths = std::collections::HashSet::new(); + for (model_id, url, path) in mmproj_downloads_needed { + if !path.exists() && started_paths.insert(path.clone()) { + let download_id = format!("{}-mmproj", model_id); + let dominated_by_active = dm + .get_progress(&download_id) + .is_some_and(|p| p.status == goose::download_manager::DownloadStatus::Downloading); + if !dominated_by_active { + tracing::info!(model_id = %model_id, "Auto-downloading vision encoder for existing model"); + if let Err(e) = dm.download_model(download_id, url, path, None).await { + tracing::warn!(model_id = %model_id, error = %e, "Failed to start mmproj download"); + } + } + } } Ok(()) @@ -154,6 +225,28 @@ pub async fn list_local_models( let size_bytes = entry.file_size(); + let vision_capable = entry.settings.vision_capable; + let mmproj_status = if vision_capable { + let ms = entry.mmproj_download_status(); + Some(match ms { + RegistryDownloadStatus::NotDownloaded => ModelDownloadStatus::NotDownloaded, + RegistryDownloadStatus::Downloading { + progress_percent, + bytes_downloaded, + total_bytes, + speed_bps, + } => ModelDownloadStatus::Downloading { + progress_percent, + bytes_downloaded, + total_bytes, + speed_bps: Some(speed_bps), + }, + RegistryDownloadStatus::Downloaded => ModelDownloadStatus::Downloaded, + }) + } else { + None + }; + models.push(LocalModelResponse { id: entry.id.clone(), repo_id: entry.repo_id.clone(), @@ -163,6 +256,8 @@ pub async fn list_local_models( status, recommended: recommended_id == entry.id, settings: entry.settings.clone(), + vision_capable, + mmproj_status, }); } @@ -276,16 +371,26 @@ pub async fn download_hf_model( source_url: download_url.clone(), settings: default_settings_for_model(&model_id), size_bytes: hf_file.size_bytes, + mmproj_path: None, + mmproj_source_url: None, + mmproj_size_bytes: 0, }; - { + // add_model enriches the entry with mmproj metadata from the featured table + let mmproj_path = { let mut registry = get_registry() .lock() .map_err(|_| ErrorResponse::internal("Failed to acquire registry lock"))?; registry .add_model(entry) .map_err(|e| ErrorResponse::internal(format!("{}", e)))?; - } + registry.get_model(&model_id).and_then(|e| { + e.mmproj_path + .as_ref() + .zip(e.mmproj_source_url.as_ref()) + .map(|(p, u)| (p.clone(), u.clone())) + }) + }; let dm = get_download_manager(); dm.download_model( @@ -297,6 +402,19 @@ pub async fn download_hf_model( .await .map_err(|e| ErrorResponse::internal(format!("Download failed: {}", e)))?; + if let Some((mmproj_path, mmproj_url)) = mmproj_path { + if !mmproj_path.exists() { + dm.download_model( + format!("{}-mmproj", model_id), + mmproj_url, + mmproj_path, + None, + ) + .await + .map_err(|e| ErrorResponse::internal(format!("mmproj download failed: {}", e)))?; + } + } + Ok((StatusCode::ACCEPTED, Json(model_id))) } @@ -338,6 +456,7 @@ pub async fn cancel_local_model_download( manager .cancel_download(&format!("{}-model", model_id)) .map_err(|e| ErrorResponse::internal(format!("{}", e)))?; + let _ = manager.cancel_download(&format!("{}-mmproj", model_id)); Ok(StatusCode::OK) } @@ -351,14 +470,22 @@ pub async fn cancel_local_model_download( ) )] pub async fn delete_local_model(Path(model_id): Path) -> Result { - let local_path = { + let (local_path, mmproj_path, other_uses_mmproj) = { let registry = get_registry() .lock() .map_err(|_| ErrorResponse::internal("Failed to acquire registry lock"))?; let entry = registry .get_model(&model_id) .ok_or_else(|| 
ErrorResponse::not_found("Model not found"))?; - entry.local_path.clone() + let lp = entry.local_path.clone(); + let mp = entry.mmproj_path.clone(); + // Check if another downloaded model shares this mmproj file + let shared = mp.as_ref().is_some_and(|target| { + registry.list_models().iter().any(|m| { + m.id != model_id && m.is_downloaded() && m.mmproj_path.as_ref() == Some(target) + }) + }); + (lp, mp, shared) }; if local_path.exists() { @@ -367,6 +494,14 @@ pub async fn delete_local_model(Path(model_id): Path) -> Result Option<( - PathBuf, - usize, - crate::providers::local_inference::local_model_registry::ModelSettings, -)> { +pub struct ResolvedModelPaths { + pub model_path: PathBuf, + pub context_limit: usize, + pub settings: crate::providers::local_inference::local_model_registry::ModelSettings, + pub mmproj_path: Option, +} + +/// Resolve model path, context limit, settings, and mmproj path for a model ID from the registry. +pub fn resolve_model_path(model_id: &str) -> Option { use crate::providers::local_inference::local_model_registry::{ default_settings_for_model, get_registry, }; @@ -135,7 +138,15 @@ pub fn resolve_model_path( // recognized (or with a different quantization) still get the right behavior. let defaults = default_settings_for_model(model_id); settings.native_tool_calling = defaults.native_tool_calling; - return Some((entry.local_path.clone(), ctx, settings)); + settings.vision_capable = defaults.vision_capable; + settings.mmproj_size_bytes = entry.mmproj_size_bytes; + let mmproj_path = entry.mmproj_path.as_ref().filter(|p| p.exists()).cloned(); + return Some(ResolvedModelPaths { + model_path: entry.local_path.clone(), + context_limit: ctx, + settings, + mmproj_path, + }); } } @@ -208,9 +219,33 @@ fn build_openai_messages_json(system: &str, messages: &[Message]) -> String { let mut arr: Vec = vec![json!({"role": "system", "content": system})]; arr.extend(format_messages(messages, &ImageFormat::OpenAi)); + strip_image_parts_from_messages(&mut arr); serde_json::to_string(&arr).unwrap_or_else(|_| "[]".to_string()) } +/// Remove `image_url` content parts from OpenAI-format messages JSON, replacing +/// each with a text note. This prevents an FFI crash in llama.cpp which does not +/// accept `image_url` content-part types. +fn strip_image_parts_from_messages(messages: &mut [Value]) { + let mut stripped = false; + for msg in messages.iter_mut() { + if let Some(content) = msg.get_mut("content").and_then(|c| c.as_array_mut()) { + for part in content.iter_mut() { + if part.get("type").and_then(|t| t.as_str()) == Some("image_url") { + *part = json!({ + "type": "text", + "text": "[Image attached — image input is not supported with the currently selected model]" + }); + stripped = true; + } + } + } + } + if stripped { + tracing::warn!("Stripped image content parts from messages — vision encoder not available for this model"); + } +} + /// Convert a message into plain text for the emulator path's chat history. 
/// /// This is the emulator-path counterpart of [`format_messages`] used by the native @@ -269,6 +304,12 @@ fn extract_text_content(msg: &Message) -> String { parts.push(format!("Command error: {}", e)); } }, + MessageContent::Image(_) => { + parts.push( + "[Image attached — image input is not supported with the currently selected model]" + .to_string(), + ); + } _ => {} } } @@ -331,8 +372,9 @@ impl LocalInferenceProvider { model_id: &str, settings: &crate::providers::local_inference::local_model_registry::ModelSettings, ) -> Result { - let (model_path, _context_limit, _) = resolve_model_path(model_id) + let resolved = resolve_model_path(model_id) .ok_or_else(|| ProviderError::ExecutionError(format!("Unknown model: {}", model_id)))?; + let model_path = resolved.model_path; if !model_path.exists() { return Err(ProviderError::ExecutionError(format!( @@ -368,9 +410,49 @@ impl LocalInferenceProvider { } }; + let mtmd_ctx = Self::init_mtmd_context(&model, &resolved.mmproj_path, settings); + tracing::info!(model_id = model_id, "Model loaded successfully"); - Ok(LoadedModel { model, template }) + Ok(LoadedModel { + model, + template, + mtmd_ctx, + }) + } + + fn init_mtmd_context( + model: &LlamaModel, + mmproj_path: &Option, + settings: &crate::providers::local_inference::local_model_registry::ModelSettings, + ) -> Option { + use llama_cpp_2::mtmd::{MtmdContext, MtmdContextParams}; + + let mmproj_path = mmproj_path.as_ref().filter(|p| p.exists())?; + + let params = MtmdContextParams { + use_gpu: true, + n_threads: settings + .n_threads + .unwrap_or_else(|| MtmdContextParams::default().n_threads), + ..MtmdContextParams::default() + }; + + match MtmdContext::init_from_file(mmproj_path.to_str().unwrap_or_default(), model, ¶ms) + { + Ok(ctx) => { + tracing::info!( + vision = ctx.support_vision(), + audio = ctx.support_audio(), + "Multimodal context initialized" + ); + Some(ctx) + } + Err(e) => { + tracing::warn!(error = %e, "Failed to init multimodal context"); + None + } + } } } @@ -453,13 +535,11 @@ impl Provider for LocalInferenceProvider { messages: &[Message], tools: &[Tool], ) -> Result { - let (_model_path, model_context_limit, model_settings) = - resolve_model_path(&model_config.model_name).ok_or_else(|| { - ProviderError::ExecutionError(format!( - "Model not found: {}", - model_config.model_name - )) - })?; + let resolved = resolve_model_path(&model_config.model_name).ok_or_else(|| { + ProviderError::ExecutionError(format!("Model not found: {}", model_config.model_name)) + })?; + let model_context_limit = resolved.context_limit; + let model_settings = resolved.settings; // Ensure model is loaded — unload any other models first to free memory. { @@ -503,6 +583,18 @@ impl Provider for LocalInferenceProvider { system.to_string() }; + // Extract images for vision-capable models, replacing them with markers. + // For non-vision models, leave messages unchanged (existing strip logic handles them). 
+ let has_vision = resolved.mmproj_path.is_some(); + let marker = llama_cpp_2::mtmd::mtmd_default_marker(); + let (images, vision_messages): (Vec, Option>) = if has_vision { + let (imgs, msgs) = multimodal::extract_images_from_messages(messages, marker); + (imgs, Some(msgs)) + } else { + (Vec::new(), None) + }; + let effective_messages: &[Message] = vision_messages.as_deref().unwrap_or(messages); + // Build chat messages for the template let mut chat_messages = vec![ @@ -529,7 +621,7 @@ impl Provider for LocalInferenceProvider { })?]; } - for msg in messages { + for msg in effective_messages { let role = match msg.role { Role::User => "user", Role::Assistant => "assistant", @@ -553,7 +645,10 @@ impl Provider for LocalInferenceProvider { }; let oai_messages_json = if model_settings.use_jinja || native_tool_calling { - Some(build_openai_messages_json(&system_prompt, messages)) + Some(build_openai_messages_json( + &system_prompt, + effective_messages, + )) } else { None }; @@ -563,6 +658,7 @@ impl Provider for LocalInferenceProvider { let model_name = model_config.model_name.clone(); let context_limit = model_context_limit; let settings = model_settings; + let mmproj_path = resolved.mmproj_path.clone(); let log_payload = serde_json::json!({ "system": &system_prompt, @@ -604,8 +700,8 @@ impl Provider for LocalInferenceProvider { }}; } - let model_guard = model_arc.blocking_lock(); - let loaded = match model_guard.as_ref() { + let mut model_guard = model_arc.blocking_lock(); + let loaded = match model_guard.as_mut() { Some(l) => l, None => { send_err!(ProviderError::ExecutionError( @@ -614,6 +710,16 @@ impl Provider for LocalInferenceProvider { } }; + // Lazily initialize the multimodal context if the vision encoder + // was downloaded after the model was loaded. 
+ if !images.is_empty() && loaded.mtmd_ctx.is_none() { + loaded.mtmd_ctx = LocalInferenceProvider::init_mtmd_context( + &loaded.model, + &mmproj_path, + &settings, + ); + } + let message_id = Uuid::new_v4().to_string(); let mut gen_ctx = GenerationContext { @@ -626,6 +732,7 @@ impl Provider for LocalInferenceProvider { message_id: &message_id, tx: &tx, log: &mut log, + images: &images, }; let result = if use_emulator { diff --git a/crates/goose/src/providers/local_inference/inference_emulated_tools.rs b/crates/goose/src/providers/local_inference/inference_emulated_tools.rs index 9e04852a4a..90732de690 100644 --- a/crates/goose/src/providers/local_inference/inference_emulated_tools.rs +++ b/crates/goose/src/providers/local_inference/inference_emulated_tools.rs @@ -29,8 +29,8 @@ use std::borrow::Cow; use uuid::Uuid; use super::inference_engine::{ - create_and_prefill_context, generation_loop, validate_and_compute_context, GenerationContext, - TokenAction, + create_and_prefill_context, create_and_prefill_multimodal, generation_loop, + validate_and_compute_context, GenerationContext, TokenAction, }; use super::{finalize_usage, StreamSender, CODE_EXECUTION_TOOL, SHELL_TOOL}; @@ -370,26 +370,32 @@ pub(super) fn generate_with_emulated_tools( ProviderError::ExecutionError(format!("Failed to apply chat template: {}", e)) })?; - let tokens = ctx - .loaded - .model - .str_to_token(&prompt, AddBos::Never) - .map_err(|e| ProviderError::ExecutionError(e.to_string()))?; - - let (prompt_token_count, effective_ctx) = validate_and_compute_context( - ctx.loaded, - ctx.runtime, - tokens.len(), - ctx.context_limit, - ctx.settings, - )?; - let mut llama_ctx = create_and_prefill_context( - ctx.loaded, - ctx.runtime, - &tokens, - effective_ctx, - ctx.settings, - )?; + let (mut llama_ctx, prompt_token_count, effective_ctx) = if !ctx.images.is_empty() { + create_and_prefill_multimodal( + ctx.loaded, + ctx.runtime, + &prompt, + ctx.images, + ctx.context_limit, + ctx.settings, + )? 
+ } else { + let tokens = ctx + .loaded + .model + .str_to_token(&prompt, AddBos::Never) + .map_err(|e| ProviderError::ExecutionError(e.to_string()))?; + let (ptc, ectx) = validate_and_compute_context( + ctx.loaded, + ctx.runtime, + tokens.len(), + ctx.context_limit, + ctx.settings, + )?; + let lctx = + create_and_prefill_context(ctx.loaded, ctx.runtime, &tokens, ectx, ctx.settings)?; + (lctx, ptc, ectx) + }; let message_id = ctx.message_id; let tx = ctx.tx; diff --git a/crates/goose/src/providers/local_inference/inference_engine.rs b/crates/goose/src/providers/local_inference/inference_engine.rs index 23ec84f293..f9ae886b7a 100644 --- a/crates/goose/src/providers/local_inference/inference_engine.rs +++ b/crates/goose/src/providers/local_inference/inference_engine.rs @@ -1,9 +1,11 @@ use crate::providers::errors::ProviderError; use crate::providers::local_inference::local_model_registry::ModelSettings; +use crate::providers::local_inference::multimodal::ExtractedImage; use crate::providers::utils::RequestLog; use llama_cpp_2::context::params::LlamaContextParams; use llama_cpp_2::llama_batch::LlamaBatch; use llama_cpp_2::model::{LlamaChatMessage, LlamaChatTemplate, LlamaModel}; +use llama_cpp_2::mtmd::{MtmdBitmap, MtmdContext, MtmdInputText}; use llama_cpp_2::sampling::LlamaSampler; use std::num::NonZeroU32; @@ -19,11 +21,14 @@ pub(super) struct GenerationContext<'a> { pub message_id: &'a str, pub tx: &'a StreamSender, pub log: &'a mut RequestLog, + pub images: &'a [ExtractedImage], } pub(super) struct LoadedModel { pub model: LlamaModel, pub template: LlamaChatTemplate, + /// Multimodal context for vision models. None for text-only models. + pub mtmd_ctx: Option, } /// Estimate the maximum context length that can fit in available accelerator/CPU @@ -33,11 +38,13 @@ pub(super) struct LoadedModel { pub(super) fn estimate_max_context_for_memory( model: &LlamaModel, runtime: &InferenceRuntime, + mmproj_overhead_bytes: u64, ) -> Option { - let available = super::available_inference_memory_bytes(runtime); - if available == 0 { + let raw_available = super::available_inference_memory_bytes(runtime); + if raw_available == 0 { return None; } + let available = raw_available.saturating_sub(mmproj_overhead_bytes); // Reserve memory for computation scratch buffers (attention, etc.) and other overhead. // The compute buffer can be 40-50% of the KV cache size for large models, so we @@ -209,7 +216,12 @@ pub(super) fn validate_and_compute_context( settings: &crate::providers::local_inference::local_model_registry::ModelSettings, ) -> Result<(usize, usize), ProviderError> { let n_ctx_train = loaded.model.n_ctx_train() as usize; - let memory_max_ctx = estimate_max_context_for_memory(&loaded.model, runtime); + let mmproj_overhead = if loaded.mtmd_ctx.is_some() { + settings.mmproj_size_bytes + } else { + 0 + }; + let memory_max_ctx = estimate_max_context_for_memory(&loaded.model, runtime, mmproj_overhead); let effective_ctx = effective_context_size( prompt_token_count, settings, @@ -261,6 +273,80 @@ pub(super) fn create_and_prefill_context<'model>( Ok(ctx) } +/// Tokenize text + images via mtmd and prefill the context. +/// +/// Returns the llama context, the number of prompt tokens consumed, +/// and the effective context size. 
+pub(super) fn create_and_prefill_multimodal<'model>(
+    loaded: &'model LoadedModel,
+    runtime: &InferenceRuntime,
+    prompt_text: &str,
+    images: &[ExtractedImage],
+    context_limit: usize,
+    settings: &ModelSettings,
+) -> Result<(llama_cpp_2::context::LlamaContext<'model>, usize, usize), ProviderError> {
+    let mtmd_ctx = loaded.mtmd_ctx.as_ref().ok_or_else(|| {
+        ProviderError::ExecutionError(
+            "This model does not have vision support. Download the vision encoder from \
+             Settings > Local Inference, or use a text-only message."
+                .to_string(),
+        )
+    })?;
+
+    let bitmaps: Vec<MtmdBitmap> = images
+        .iter()
+        .map(|img| {
+            MtmdBitmap::from_buffer(mtmd_ctx, &img.bytes)
+                .map_err(|e| ProviderError::ExecutionError(format!("Failed to decode image: {e}")))
+        })
+        .collect::<Result<Vec<_>, _>>()?;
+
+    let bitmap_refs: Vec<&MtmdBitmap> = bitmaps.iter().collect();
+
+    let input_text = MtmdInputText {
+        text: prompt_text.to_string(),
+        add_special: true,
+        parse_special: true,
+    };
+    let chunks = mtmd_ctx.tokenize(input_text, &bitmap_refs).map_err(|e| {
+        ProviderError::ExecutionError(format!("Multimodal tokenization failed: {e}"))
+    })?;
+
+    let prompt_token_count = chunks.total_tokens();
+
+    let n_ctx_train = loaded.model.n_ctx_train() as usize;
+    let mmproj_overhead = settings.mmproj_size_bytes;
+    let memory_max_ctx = estimate_max_context_for_memory(&loaded.model, runtime, mmproj_overhead);
+    let effective_ctx = effective_context_size(
+        prompt_token_count,
+        settings,
+        context_limit,
+        n_ctx_train,
+        memory_max_ctx,
+    );
+
+    let min_generation_headroom = 512;
+    if prompt_token_count + min_generation_headroom > effective_ctx {
+        return Err(ProviderError::ContextLengthExceeded(format!(
+            "Multimodal prompt ({prompt_token_count} tokens including images) exceeds \
+             context limit ({effective_ctx} tokens)",
+        )));
+    }
+
+    let ctx_params = build_context_params(effective_ctx as u32, settings);
+    let llama_ctx = loaded
+        .model
+        .new_context(runtime.backend(), ctx_params)
+        .map_err(|e| ProviderError::ExecutionError(format!("Failed to create context: {e}")))?;
+
+    let n_batch = llama_ctx.n_batch() as i32;
+    let _n_past = chunks
+        .eval_chunks(mtmd_ctx, &llama_ctx, 0, 0, n_batch, true)
+        .map_err(|e| ProviderError::ExecutionError(format!("Multimodal eval failed: {e}")))?;
+
+    Ok((llama_ctx, prompt_token_count, effective_ctx))
+}
+
 /// Action to take after processing a generated token piece.
pub(super) enum TokenAction { Continue, diff --git a/crates/goose/src/providers/local_inference/inference_native_tools.rs b/crates/goose/src/providers/local_inference/inference_native_tools.rs index 15850346f8..e89c50c4a0 100644 --- a/crates/goose/src/providers/local_inference/inference_native_tools.rs +++ b/crates/goose/src/providers/local_inference/inference_native_tools.rs @@ -9,8 +9,9 @@ use uuid::Uuid; use super::finalize_usage; use super::inference_engine::{ - context_cap, create_and_prefill_context, estimate_max_context_for_memory, generation_loop, - validate_and_compute_context, GenerationContext, TokenAction, + context_cap, create_and_prefill_context, create_and_prefill_multimodal, + estimate_max_context_for_memory, generation_loop, validate_and_compute_context, + GenerationContext, TokenAction, }; pub(super) fn generate_with_native_tools( @@ -21,7 +22,13 @@ pub(super) fn generate_with_native_tools( ) -> Result<(), ProviderError> { let min_generation_headroom = 512; let n_ctx_train = ctx.loaded.model.n_ctx_train() as usize; - let memory_max_ctx = estimate_max_context_for_memory(&ctx.loaded.model, ctx.runtime); + let mmproj_overhead = if ctx.loaded.mtmd_ctx.is_some() { + ctx.settings.mmproj_size_bytes + } else { + 0 + }; + let memory_max_ctx = + estimate_max_context_for_memory(&ctx.loaded.model, ctx.runtime, mmproj_overhead); let cap = context_cap(ctx.settings, ctx.context_limit, n_ctx_train, memory_max_ctx); let token_budget = cap.saturating_sub(min_generation_headroom); @@ -61,6 +68,8 @@ pub(super) fn generate_with_native_tools( } }; + let estimated_image_tokens = ctx.images.len() * ctx.settings.image_token_estimate; + let template_result = match apply_template(full_tools_json) { Ok(r) => { let token_count = ctx @@ -69,7 +78,7 @@ pub(super) fn generate_with_native_tools( .str_to_token(&r.prompt, AddBos::Never) .map(|t| t.len()) .unwrap_or(0); - if token_count > token_budget { + if token_count + estimated_image_tokens > token_budget { apply_template(compact_tools).unwrap_or(r) } else { r @@ -85,26 +94,32 @@ pub(super) fn generate_with_native_tools( None, ); - let tokens = ctx - .loaded - .model - .str_to_token(&template_result.prompt, AddBos::Never) - .map_err(|e| ProviderError::ExecutionError(e.to_string()))?; - - let (prompt_token_count, effective_ctx) = validate_and_compute_context( - ctx.loaded, - ctx.runtime, - tokens.len(), - ctx.context_limit, - ctx.settings, - )?; - let mut llama_ctx = create_and_prefill_context( - ctx.loaded, - ctx.runtime, - &tokens, - effective_ctx, - ctx.settings, - )?; + let (mut llama_ctx, prompt_token_count, effective_ctx) = if !ctx.images.is_empty() { + create_and_prefill_multimodal( + ctx.loaded, + ctx.runtime, + &template_result.prompt, + ctx.images, + ctx.context_limit, + ctx.settings, + )? 
+ } else { + let tokens = ctx + .loaded + .model + .str_to_token(&template_result.prompt, AddBos::Never) + .map_err(|e| ProviderError::ExecutionError(e.to_string()))?; + let (ptc, ectx) = validate_and_compute_context( + ctx.loaded, + ctx.runtime, + tokens.len(), + ctx.context_limit, + ctx.settings, + )?; + let lctx = + create_and_prefill_context(ctx.loaded, ctx.runtime, &tokens, ectx, ctx.settings)?; + (lctx, ptc, ectx) + }; let message_id = ctx.message_id; let tx = ctx.tx; diff --git a/crates/goose/src/providers/local_inference/local_model_registry.rs b/crates/goose/src/providers/local_inference/local_model_registry.rs index 1b5917d38c..d79727ee71 100644 --- a/crates/goose/src/providers/local_inference/local_model_registry.rs +++ b/crates/goose/src/providers/local_inference/local_model_registry.rs @@ -62,12 +62,27 @@ pub struct ModelSettings { pub use_jinja: bool, #[serde(default = "default_true")] pub enable_thinking: bool, + /// Whether this model architecture supports vision input. + /// Derived from the featured model table, not user-configurable. + #[serde(default)] + pub vision_capable: bool, + /// Estimated tokens per image for budget planning before mtmd tokenization. + /// The actual count is determined after tokenization via `chunks.total_tokens()`. + #[serde(default = "default_image_token_estimate")] + pub image_token_estimate: usize, + /// Size of the mmproj file in bytes, used for memory accounting. + #[serde(default)] + pub mmproj_size_bytes: u64, } fn default_true() -> bool { true } +fn default_image_token_estimate() -> usize { + 256 +} + fn default_repeat_penalty() -> f32 { 1.0 } @@ -94,41 +109,75 @@ impl Default for ModelSettings { native_tool_calling: false, use_jinja: false, enable_thinking: true, + vision_capable: false, + image_token_estimate: default_image_token_estimate(), + mmproj_size_bytes: 0, } } } +/// HuggingFace repo + filename for multimodal projection weights (vision encoder). +pub struct MmprojSpec { + pub repo: &'static str, + pub filename: &'static str, +} + +impl MmprojSpec { + /// Local path for this mmproj, namespaced by repo to avoid collisions + /// between different models that use the same filename. + pub fn local_path(&self) -> std::path::PathBuf { + let repo_name = self.repo.split('/').next_back().unwrap_or(self.repo); + Paths::in_data_dir("models") + .join(repo_name) + .join(self.filename) + } +} + pub struct FeaturedModel { /// HuggingFace spec in "author/repo-GGUF:quantization" format. pub spec: &'static str, /// Whether this model's GGUF template supports native tool calling via llama.cpp. pub native_tool_calling: bool, + /// Multimodal projection weights spec. None for text-only models. 
+ pub mmproj: Option, } pub const FEATURED_MODELS: &[FeaturedModel] = &[ FeaturedModel { spec: "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", native_tool_calling: false, + mmproj: None, }, FeaturedModel { spec: "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", native_tool_calling: false, + mmproj: None, }, FeaturedModel { spec: "bartowski/Hermes-2-Pro-Mistral-7B-GGUF:Q4_K_M", native_tool_calling: false, + mmproj: None, }, FeaturedModel { spec: "bartowski/Mistral-Small-24B-Instruct-2501-GGUF:Q4_K_M", native_tool_calling: false, + mmproj: None, }, FeaturedModel { spec: "unsloth/gemma-4-E4B-it-GGUF:Q4_K_M", native_tool_calling: true, + mmproj: Some(MmprojSpec { + repo: "unsloth/gemma-4-E4B-it-GGUF", + filename: "mmproj-BF16.gguf", + }), }, FeaturedModel { spec: "unsloth/gemma-4-26B-A4B-it-GGUF:Q4_K_M", native_tool_calling: true, + mmproj: Some(MmprojSpec { + repo: "unsloth/gemma-4-26B-A4B-it-GGUF", + filename: "mmproj-BF16.gguf", + }), }, ]; @@ -144,10 +193,25 @@ pub fn default_settings_for_model(model_id: &str) -> ModelSettings { }); ModelSettings { native_tool_calling: featured.is_some_and(|m| m.native_tool_calling), + vision_capable: featured.is_some_and(|m| m.mmproj.is_some()), ..ModelSettings::default() } } +/// Look up the `MmprojSpec` for a featured model by its model ID. +pub fn featured_mmproj_spec(model_id: &str) -> Option<&'static MmprojSpec> { + use super::hf_models::parse_model_spec; + let model_repo = model_id.split(':').next().unwrap_or(model_id); + FEATURED_MODELS.iter().find_map(|m| { + if let Ok((repo_id, _quant)) = parse_model_spec(m.spec) { + if repo_id == model_repo { + return m.mmproj.as_ref(); + } + } + None + }) +} + /// Check if a model ID corresponds to a featured model. pub fn is_featured_model(model_id: &str) -> bool { use super::hf_models::parse_model_spec; @@ -181,9 +245,42 @@ pub struct LocalModelEntry { pub settings: ModelSettings, #[serde(default)] pub size_bytes: u64, + /// Local path to the multimodal projection GGUF (vision encoder). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub mmproj_path: Option, + /// Download URL for the mmproj file. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub mmproj_source_url: Option, + /// Size of the mmproj file in bytes. + #[serde(default)] + pub mmproj_size_bytes: u64, } impl LocalModelEntry { + /// Populate mmproj metadata and vision settings from the featured model + /// table if this model's repo has a known vision encoder. 
+ pub fn enrich_with_featured_mmproj(&mut self) { + if let Some(mmproj) = featured_mmproj_spec(&self.id) { + let path = mmproj.local_path(); + if self.mmproj_path.as_ref() != Some(&path) { + self.mmproj_path = Some(path.clone()); + self.mmproj_source_url = Some(format!( + "https://huggingface.co/{}/resolve/main/{}", + mmproj.repo, mmproj.filename + )); + } + self.settings.vision_capable = true; + if self.mmproj_size_bytes == 0 || self.settings.mmproj_size_bytes == 0 { + if let Ok(meta) = std::fs::metadata(&path) { + self.mmproj_size_bytes = meta.len(); + self.settings.mmproj_size_bytes = meta.len(); + } + } + } + let defaults = default_settings_for_model(&self.id); + self.settings.native_tool_calling = defaults.native_tool_calling; + } + pub fn is_downloaded(&self) -> bool { self.local_path.exists() } @@ -219,6 +316,36 @@ impl LocalModelEntry { ModelDownloadStatus::NotDownloaded } + pub fn has_vision(&self) -> bool { + self.mmproj_path.as_ref().is_some_and(|p| p.exists()) + } + + pub fn mmproj_download_status(&self) -> ModelDownloadStatus { + if let Some(path) = &self.mmproj_path { + if path.exists() { + return ModelDownloadStatus::Downloaded; + } + } else { + return ModelDownloadStatus::NotDownloaded; + } + + let download_id = format!("{}-mmproj", self.id); + let manager = get_download_manager(); + if let Some(progress) = manager.get_progress(&download_id) { + return match progress.status { + DownloadStatus::Downloading => ModelDownloadStatus::Downloading { + progress_percent: progress.progress_percent, + bytes_downloaded: progress.bytes_downloaded, + total_bytes: progress.total_bytes, + speed_bps: progress.speed_bps.unwrap_or(0), + }, + _ => ModelDownloadStatus::NotDownloaded, + }; + } + + ModelDownloadStatus::NotDownloaded + } + pub fn file_size(&self) -> u64 { if self.size_bytes > 0 { return self.size_bytes; @@ -290,8 +417,9 @@ impl LocalModelRegistry { pub fn sync_with_featured(&mut self, featured_entries: Vec) { let mut changed = false; - for entry in featured_entries { + for mut entry in featured_entries { if !self.models.iter().any(|m| m.id == entry.id) { + entry.enrich_with_featured_mmproj(); self.models.push(entry); changed = true; } @@ -309,7 +437,8 @@ impl LocalModelRegistry { } } - pub fn add_model(&mut self, entry: LocalModelEntry) -> Result<()> { + pub fn add_model(&mut self, mut entry: LocalModelEntry) -> Result<()> { + entry.enrich_with_featured_mmproj(); if let Some(existing) = self.models.iter_mut().find(|m| m.id == entry.id) { *existing = entry; } else { @@ -348,6 +477,10 @@ impl LocalModelRegistry { pub fn list_models(&self) -> &[LocalModelEntry] { &self.models } + + pub fn list_models_mut(&mut self) -> &mut [LocalModelEntry] { + &mut self.models + } } /// Generate a unique ID for a model from its repo_id and quantization. diff --git a/crates/goose/src/providers/local_inference/multimodal.rs b/crates/goose/src/providers/local_inference/multimodal.rs new file mode 100644 index 0000000000..157e2faec0 --- /dev/null +++ b/crates/goose/src/providers/local_inference/multimodal.rs @@ -0,0 +1,336 @@ +use base64::prelude::*; +use serde_json::Value; + +use crate::conversation::message::{Message, MessageContent}; +use crate::providers::errors::ProviderError; + +#[derive(Debug)] +pub struct ExtractedImage { + pub bytes: Vec, +} + +#[derive(Debug)] +#[allow(dead_code)] +pub struct MultimodalMessages { + pub messages_json: String, + pub images: Vec, +} + +/// Walk the OpenAI-format messages JSON array. 
For each content part with
+/// `type: "image_url"`, decode the base64 data URL, store the raw bytes,
+/// and replace the part with `{"type": "text", "text": "<marker>"}`.
+///
+/// Returns the modified JSON string and the extracted images in order.
+#[allow(dead_code)]
+pub fn extract_images_from_messages_json(
+    messages_json: &str,
+    marker: &str,
+) -> Result<MultimodalMessages, ProviderError> {
+    let mut messages: Vec<Value> = serde_json::from_str(messages_json).map_err(|e| {
+        ProviderError::ExecutionError(format!("Failed to parse messages JSON: {e}"))
+    })?;
+
+    let mut images = Vec::new();
+
+    for msg in messages.iter_mut() {
+        let Some(content) = msg.get_mut("content").and_then(|c| c.as_array_mut()) else {
+            continue;
+        };
+
+        for part in content.iter_mut() {
+            if part.get("type").and_then(|t| t.as_str()) != Some("image_url") {
+                continue;
+            }
+
+            let url = part
+                .get("image_url")
+                .and_then(|obj| obj.get("url"))
+                .and_then(|u| u.as_str())
+                .unwrap_or_default();
+
+            if url.starts_with("http://") || url.starts_with("https://") {
+                return Err(ProviderError::ExecutionError(
+                    "Remote image URLs are not supported with local inference. \
+                     Please attach the image directly."
+                        .to_string(),
+                ));
+            }
+
+            let base64_data = url.split_once(',').map_or(url, |(_, data)| data);
+
+            let bytes = BASE64_STANDARD.decode(base64_data).map_err(|e| {
+                ProviderError::ExecutionError(format!("Failed to decode base64 image: {e}"))
+            })?;
+
+            images.push(ExtractedImage { bytes });
+
+            *part = serde_json::json!({
+                "type": "text",
+                "text": marker,
+            });
+        }
+    }
+
+    let messages_json = serde_json::to_string(&messages)
+        .map_err(|e| ProviderError::ExecutionError(format!("Failed to serialize messages: {e}")))?;
+
+    Ok(MultimodalMessages {
+        messages_json,
+        images,
+    })
+}
+
+/// Scan messages for `MessageContent::Image` entries. Return the extracted image
+/// bytes and a new message list with images replaced by text marker placeholders.
+pub fn extract_images_from_messages( + messages: &[Message], + marker: &str, +) -> (Vec, Vec) { + let mut images = Vec::new(); + let mut new_messages = Vec::with_capacity(messages.len()); + + for msg in messages { + let mut new_content = Vec::with_capacity(msg.content.len()); + for content in &msg.content { + match content { + MessageContent::Image(img) => { + if let Ok(bytes) = BASE64_STANDARD.decode(&img.data) { + images.push(ExtractedImage { bytes }); + new_content.push(MessageContent::text(marker)); + } else { + new_content.push(MessageContent::text( + "[Image attached — failed to decode image data]", + )); + } + } + other => new_content.push(other.clone()), + } + } + new_messages.push(Message { + role: msg.role.clone(), + content: new_content, + ..msg.clone() + }); + } + + (images, new_messages) +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + fn make_test_messages_json(parts: Vec) -> String { + serde_json::to_string(&vec![json!({ + "role": "user", + "content": parts, + })]) + .unwrap() + } + + fn tiny_png_base64() -> String { + // 1x1 red PNG + let bytes: &[u8] = &[ + 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, + 0x44, 0x52, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02, 0x00, 0x00, + 0x00, 0x90, 0x77, 0x53, 0xDE, 0x00, 0x00, 0x00, 0x0C, 0x49, 0x44, 0x41, 0x54, 0x08, + 0xD7, 0x63, 0xF8, 0xCF, 0xC0, 0x00, 0x00, 0x00, 0x03, 0x00, 0x01, 0x36, 0x28, 0x19, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4E, 0x44, 0xAE, 0x42, 0x60, 0x82, + ]; + BASE64_STANDARD.encode(bytes) + } + + #[test] + fn test_extract_images_replaces_image_url_with_marker() { + let b64 = tiny_png_base64(); + let json = make_test_messages_json(vec![json!({ + "type": "image_url", + "image_url": {"url": format!("data:image/png;base64,{b64}")} + })]); + + let result = extract_images_from_messages_json(&json, "<__media__>").unwrap(); + assert_eq!(result.images.len(), 1); + assert!(!result.images[0].bytes.is_empty()); + + let parsed: Vec = serde_json::from_str(&result.messages_json).unwrap(); + let content = parsed[0]["content"].as_array().unwrap(); + assert_eq!(content[0]["type"], "text"); + assert_eq!(content[0]["text"], "<__media__>"); + } + + #[test] + fn test_extract_images_preserves_text_parts() { + let json = make_test_messages_json(vec![json!({ + "type": "text", + "text": "Hello world" + })]); + + let result = extract_images_from_messages_json(&json, "<__media__>").unwrap(); + assert!(result.images.is_empty()); + + let parsed: Vec = serde_json::from_str(&result.messages_json).unwrap(); + let content = parsed[0]["content"].as_array().unwrap(); + assert_eq!(content[0]["type"], "text"); + assert_eq!(content[0]["text"], "Hello world"); + } + + #[test] + fn test_extract_images_multiple_images() { + let b64 = tiny_png_base64(); + let json = make_test_messages_json(vec![ + json!({"type": "image_url", "image_url": {"url": format!("data:image/png;base64,{b64}")}}), + json!({"type": "text", "text": "describe both"}), + json!({"type": "image_url", "image_url": {"url": format!("data:image/png;base64,{b64}")}}), + ]); + + let result = extract_images_from_messages_json(&json, "<__media__>").unwrap(); + assert_eq!(result.images.len(), 2); + + let parsed: Vec = serde_json::from_str(&result.messages_json).unwrap(); + let content = parsed[0]["content"].as_array().unwrap(); + assert_eq!(content[0]["type"], "text"); + assert_eq!(content[0]["text"], "<__media__>"); + assert_eq!(content[1]["type"], "text"); + assert_eq!(content[1]["text"], "describe both"); + 
assert_eq!(content[2]["type"], "text"); + assert_eq!(content[2]["text"], "<__media__>"); + } + + #[test] + fn test_extract_images_no_images() { + let json = make_test_messages_json(vec![json!({ + "type": "text", + "text": "just text" + })]); + + let result = extract_images_from_messages_json(&json, "<__media__>").unwrap(); + assert!(result.images.is_empty()); + // JSON should be equivalent + let original: Vec = serde_json::from_str(&json).unwrap(); + let result_parsed: Vec = serde_json::from_str(&result.messages_json).unwrap(); + assert_eq!(original, result_parsed); + } + + #[test] + fn test_extract_images_http_url_rejected() { + let json = make_test_messages_json(vec![json!({ + "type": "image_url", + "image_url": {"url": "https://example.com/image.png"} + })]); + + let result = extract_images_from_messages_json(&json, "<__media__>"); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("Remote image URLs are not supported")); + } + + #[test] + fn test_extract_images_mixed_content() { + let b64 = tiny_png_base64(); + // Two messages: first with text+image, second with just text + let json = serde_json::to_string(&vec![ + json!({ + "role": "user", + "content": [ + {"type": "text", "text": "What is this?"}, + {"type": "image_url", "image_url": {"url": format!("data:image/png;base64,{b64}")}}, + ] + }), + json!({ + "role": "assistant", + "content": [ + {"type": "text", "text": "It looks like a red pixel."}, + ] + }), + ]) + .unwrap(); + + let result = extract_images_from_messages_json(&json, "<__media__>").unwrap(); + assert_eq!(result.images.len(), 1); + + let parsed: Vec = serde_json::from_str(&result.messages_json).unwrap(); + // First message: text preserved, image replaced + let content0 = parsed[0]["content"].as_array().unwrap(); + assert_eq!(content0[0]["text"], "What is this?"); + assert_eq!(content0[1]["text"], "<__media__>"); + // Second message unchanged + let content1 = parsed[1]["content"].as_array().unwrap(); + assert_eq!(content1[0]["text"], "It looks like a red pixel."); + } + + // --- Tests for extract_images_from_messages (Message-based) --- + + #[test] + fn test_messages_extract_replaces_image_with_marker() { + let b64 = tiny_png_base64(); + let messages = vec![Message::user().with_image(b64, "image/png")]; + + let (images, new_msgs) = extract_images_from_messages(&messages, "<__media__>"); + assert_eq!(images.len(), 1); + assert!(!images[0].bytes.is_empty()); + assert_eq!(new_msgs.len(), 1); + assert_eq!(new_msgs[0].as_concat_text(), "<__media__>"); + } + + #[test] + fn test_messages_extract_preserves_text() { + let messages = vec![Message::user().with_text("Hello world")]; + + let (images, new_msgs) = extract_images_from_messages(&messages, "<__media__>"); + assert!(images.is_empty()); + assert_eq!(new_msgs[0].as_concat_text(), "Hello world"); + } + + #[test] + fn test_messages_extract_multiple_images() { + let b64 = tiny_png_base64(); + let messages = vec![Message::user() + .with_image(b64.clone(), "image/png") + .with_text("describe both") + .with_image(b64, "image/png")]; + + let (images, new_msgs) = extract_images_from_messages(&messages, "<__media__>"); + assert_eq!(images.len(), 2); + assert_eq!(new_msgs[0].content.len(), 3); + assert_eq!( + new_msgs[0].as_concat_text(), + "<__media__>\ndescribe both\n<__media__>" + ); + } + + #[test] + fn test_messages_extract_no_images() { + let messages = vec![Message::user().with_text("just text")]; + + let (images, new_msgs) = extract_images_from_messages(&messages, "<__media__>"); + 
assert!(images.is_empty()); + assert_eq!(new_msgs[0].as_concat_text(), "just text"); + } + + #[test] + fn test_messages_extract_invalid_base64() { + let messages = vec![Message::user().with_image("not-valid-base64!!!", "image/png")]; + + let (images, new_msgs) = extract_images_from_messages(&messages, "<__media__>"); + assert!(images.is_empty()); + assert!(new_msgs[0].as_concat_text().contains("failed to decode")); + } + + #[test] + fn test_messages_extract_mixed_content() { + let b64 = tiny_png_base64(); + let messages = vec![ + Message::user() + .with_text("What is this?") + .with_image(b64, "image/png"), + Message::assistant().with_text("It looks like a red pixel."), + ]; + + let (images, new_msgs) = extract_images_from_messages(&messages, "<__media__>"); + assert_eq!(images.len(), 1); + assert_eq!(new_msgs.len(), 2); + assert_eq!(new_msgs[0].as_concat_text(), "What is this?\n<__media__>"); + assert_eq!(new_msgs[1].as_concat_text(), "It looks like a red pixel."); + } +} diff --git a/crates/goose/tests/local_inference_integration.rs b/crates/goose/tests/local_inference_integration.rs index abf2af661e..971a767092 100644 --- a/crates/goose/tests/local_inference_integration.rs +++ b/crates/goose/tests/local_inference_integration.rs @@ -9,7 +9,11 @@ //! //! Run with a specific model: //! TEST_MODEL="bartowski/Qwen_Qwen3-32B-GGUF:Q4_K_M" cargo test -p goose --test local_inference_integration -- --ignored +//! +//! Run vision tests (requires a vision-capable model like gemma-4): +//! TEST_VISION_MODEL="unsloth/gemma-4-E4B-it-GGUF:Q4_K_M" cargo test -p goose --test local_inference_integration test_local_inference_vision -- --ignored +use base64::prelude::*; use futures::StreamExt; use goose::conversation::message::Message; use goose::model::ModelConfig; @@ -93,3 +97,130 @@ async fn test_local_inference_large_prompt() { text.len() ); } + +fn vision_test_model() -> Option { + std::env::var("TEST_VISION_MODEL").ok() +} + +/// Generate a small solid-colour 2x2 red PNG as raw bytes. +fn tiny_red_png() -> Vec { + vec![ + 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature + 0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52, // IHDR chunk + 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, // 1x1 + 0x08, 0x02, 0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE, // RGB, 8-bit + 0x00, 0x00, 0x00, 0x0C, 0x49, 0x44, 0x41, 0x54, // IDAT chunk + 0x08, 0xD7, 0x63, 0xF8, 0xCF, 0xC0, 0x00, 0x00, 0x00, 0x03, 0x00, 0x01, 0x36, 0x28, 0x19, + 0x00, // compressed pixel data + 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4E, 0x44, 0xAE, 0x42, 0x60, 0x82, // IEND + ] +} + +/// Test that a vision-capable local model can process a message with an embedded image +/// and produce a text response without crashing. +/// +/// Requires TEST_VISION_MODEL to be set to a downloaded vision model. +/// Example: +/// TEST_VISION_MODEL="unsloth/gemma-4-E4B-it-GGUF:Q4_K_M" \ +/// cargo test -p goose --test local_inference_integration test_local_inference_vision -- --ignored +#[tokio::test] +#[ignore] +async fn test_local_inference_vision_produces_output() { + let model_id = match vision_test_model() { + Some(id) => id, + None => { + eprintln!( + "Skipping vision test: TEST_VISION_MODEL not set. 
\ + Set it to a vision-capable model like unsloth/gemma-4-E4B-it-GGUF:Q4_K_M" + ); + return; + } + }; + + let model_config = ModelConfig::new(&model_id).expect("valid model config"); + let provider = create("local", model_config.clone(), Vec::new()) + .await + .expect("provider creation should succeed"); + + let image_bytes = tiny_red_png(); + let image_b64 = BASE64_STANDARD.encode(&image_bytes); + + let system = "You are a helpful assistant. Describe images briefly."; + let messages = vec![Message::user() + .with_text("What color is this image?") + .with_image(image_b64, "image/png")]; + + let mut stream = provider + .stream(&model_config, "test-vision-session", system, &messages, &[]) + .await + .expect("stream should start for vision input"); + + let mut got_text = false; + let mut collected_text = String::new(); + + while let Some(result) = stream.next().await { + let (msg, _usage) = result.expect("stream item should be Ok"); + if let Some(m) = msg { + got_text = true; + collected_text.push_str(&m.as_concat_text()); + } + } + + assert!( + got_text, + "vision stream should produce at least one text message" + ); + assert!( + !collected_text.is_empty(), + "vision response should contain text" + ); + println!("Vision response: {collected_text}"); +} + +/// Test that sending an image to a text-only model produces a clear error +/// rather than crashing. +#[tokio::test] +#[ignore] +async fn test_local_inference_vision_text_only_model_graceful() { + let model_config = ModelConfig::new(&test_model()).expect("valid model config"); + let provider = create("local", model_config.clone(), Vec::new()) + .await + .expect("provider creation should succeed"); + + let image_bytes = tiny_red_png(); + let image_b64 = BASE64_STANDARD.encode(&image_bytes); + + let system = "You are a helpful assistant."; + let messages = vec![Message::user() + .with_text("What is this?") + .with_image(image_b64, "image/png")]; + + let mut stream = provider + .stream(&model_config, "test-session", system, &messages, &[]) + .await + .expect("stream should start"); + + // The stream should either produce a response with the image stripped + // (placeholder text) or produce an error — but it must not crash. + let mut completed = false; + while let Some(result) = stream.next().await { + match result { + Ok(_) => completed = true, + Err(e) => { + // An error about missing vision support is acceptable + let err_msg = e.to_string(); + assert!( + err_msg.contains("vision") || err_msg.contains("image"), + "error should mention vision/image support, got: {err_msg}" + ); + completed = true; + break; + } + } + } + + assert!( + completed, + "stream should complete without crashing when images sent to text-only model" + ); +} diff --git a/goose-self-test.yaml b/goose-self-test.yaml index 8744b65aec..59e922d0ee 100644 --- a/goose-self-test.yaml +++ b/goose-self-test.yaml @@ -317,6 +317,31 @@ prompt: | Log results to: {{ workspace_dir }}/phase3_delegation.md {% endif %} + {% if test_phases == "all" or "vision" in test_phases %} + ## 📷 PHASE 3B: Local Inference Vision Testing + + **Prerequisites**: A vision-capable local model must be downloaded (e.g., gemma-4-E4B). + Skip this phase if no local vision model is available. + + ### Vision Smoke Test + 1. Create a small test image: + ``` + python3 -c "import struct, zlib; raw=b'\x00\xff\x00\x00'; d=zlib.compress(raw); ihdr=b'\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00'; print('Created test.png')" + ``` + Or simply create a 1-pixel PNG test image using available tools. + 2. 
Verify the test image file exists and is valid. + 3. Send a message to the local vision model referencing the test image. + 4. Verify the model responds with text (not an error or crash). + 5. Verify the response acknowledges the image content. + + ### Vision Error Handling Test + 1. If a text-only local model is available, send it a message with an image attached. + 2. Verify it responds gracefully (either with a placeholder message or a clear error), + not with a crash or FFI error. + + Log results to: {{ workspace_dir }}/phase3b_vision.md + {% endif %} + {% if test_phases == "all" or "advanced" in test_phases %} ## 🔬 PHASE 4: Advanced Testing diff --git a/ui/desktop/openapi.json b/ui/desktop/openapi.json index 7d9e91f831..2bc63d4db8 100644 --- a/ui/desktop/openapi.json +++ b/ui/desktop/openapi.json @@ -5724,7 +5724,8 @@ "size_bytes", "status", "recommended", - "settings" + "settings", + "vision_capable" ], "properties": { "filename": { @@ -5733,6 +5734,14 @@ "id": { "type": "string" }, + "mmproj_status": { + "allOf": [ + { + "$ref": "#/components/schemas/ModelDownloadStatus" + } + ], + "nullable": true + }, "quantization": { "type": "string" }, @@ -5752,6 +5761,9 @@ }, "status": { "$ref": "#/components/schemas/ModelDownloadStatus" + }, + "vision_capable": { + "type": "boolean" } } }, @@ -6497,11 +6509,22 @@ "type": "number", "format": "float" }, + "image_token_estimate": { + "type": "integer", + "description": "Estimated tokens per image for budget planning before mtmd tokenization.\nThe actual count is determined after tokenization via `chunks.total_tokens()`.", + "minimum": 0 + }, "max_output_tokens": { "type": "integer", "nullable": true, "minimum": 0 }, + "mmproj_size_bytes": { + "type": "integer", + "format": "int64", + "description": "Size of the mmproj file in bytes, used for memory accounting.", + "minimum": 0 + }, "n_batch": { "type": "integer", "format": "int32", @@ -6542,6 +6565,10 @@ }, "use_mlock": { "type": "boolean" + }, + "vision_capable": { + "type": "boolean", + "description": "Whether this model architecture supports vision input.\nDerived from the featured model table, not user-configurable." } } }, diff --git a/ui/desktop/src/api/types.gen.ts b/ui/desktop/src/api/types.gen.ts index 703e082282..c0c4c30be8 100644 --- a/ui/desktop/src/api/types.gen.ts +++ b/ui/desktop/src/api/types.gen.ts @@ -612,12 +612,14 @@ export type LoadedProvider = { export type LocalModelResponse = { filename: string; id: string; + mmproj_status?: ModelDownloadStatus | null; quantization: string; recommended: boolean; repo_id: string; settings: ModelSettings; size_bytes: number; status: ModelDownloadStatus; + vision_capable: boolean; }; /** @@ -821,7 +823,16 @@ export type ModelSettings = { enable_thinking?: boolean; flash_attention?: boolean | null; frequency_penalty?: number; + /** + * Estimated tokens per image for budget planning before mtmd tokenization. + * The actual count is determined after tokenization via `chunks.total_tokens()`. + */ + image_token_estimate?: number; max_output_tokens?: number | null; + /** + * Size of the mmproj file in bytes, used for memory accounting. + */ + mmproj_size_bytes?: number; n_batch?: number | null; n_gpu_layers?: number | null; n_threads?: number | null; @@ -832,6 +843,11 @@ export type ModelSettings = { sampling?: SamplingConfig; use_jinja?: boolean; use_mlock?: boolean; + /** + * Whether this model architecture supports vision input. + * Derived from the featured model table, not user-configurable. 
+ */ + vision_capable?: boolean; }; export type ModelTemplate = { diff --git a/ui/desktop/src/components/settings/localInference/LocalInferenceSettings.tsx b/ui/desktop/src/components/settings/localInference/LocalInferenceSettings.tsx index e29708cf14..5bbcd3a0df 100644 --- a/ui/desktop/src/components/settings/localInference/LocalInferenceSettings.tsx +++ b/ui/desktop/src/components/settings/localInference/LocalInferenceSettings.tsx @@ -1,5 +1,5 @@ import { useState, useEffect, useCallback, useRef } from 'react'; -import { Download, Trash2, X, ChevronDown, ChevronUp, Settings2 } from 'lucide-react'; +import { Download, Trash2, X, ChevronDown, ChevronUp, Settings2, Eye } from 'lucide-react'; import { Button } from '../../ui/button'; import { useModelAndProvider } from '../../ModelAndProviderContext'; import { defineMessages, useIntl } from '../../../i18n'; @@ -83,8 +83,57 @@ const i18n = defineMessages({ id: 'localInferenceSettings.modelSettingsTitle', defaultMessage: 'Model settings', }, + vision: { + id: 'localInferenceSettings.vision', + defaultMessage: 'Vision', + }, + visionEncoderDownloading: { + id: 'localInferenceSettings.visionEncoderDownloading', + defaultMessage: 'Vision encoder downloading…', + }, + visionEncoderNotDownloaded: { + id: 'localInferenceSettings.visionEncoderNotDownloaded', + defaultMessage: 'Vision encoder not downloaded', + }, }); +const VisionBadge = ({ model, intl }: { model: LocalModelResponse; intl: ReturnType }) => { + if (!model.vision_capable) return null; + + const mmproj = model.mmproj_status; + const isDownloaded = mmproj?.state === 'Downloaded'; + const isDownloading = mmproj?.state === 'Downloading'; + + if (isDownloaded) { + return ( + + + {intl.formatMessage(i18n.vision)} + + ); + } + + if (isDownloading) { + const percent = mmproj && 'progress_percent' in mmproj + ? Math.round(mmproj.progress_percent) + : null; + return ( + + + {intl.formatMessage(i18n.visionEncoderDownloading)} + {percent != null && ` ${percent}%`} + + ); + } + + return ( + + + {intl.formatMessage(i18n.vision)} + + ); +}; + const formatBytes = (bytes: number): string => { if (bytes < 1024) return `${bytes}B`; if (bytes < 1024 * 1024) return `${(bytes / 1024).toFixed(0)}KB`; @@ -128,6 +177,19 @@ export const LocalInferenceSettings = () => { // eslint-disable-next-line react-hooks/exhaustive-deps }, []); + // Poll model list while any vision encoder is downloading + useEffect(() => { + const hasDownloadingMmproj = models.some( + (m) => m.vision_capable && m.mmproj_status?.state === 'Downloading' + ); + if (!hasDownloadingMmproj) return; + + const interval = setInterval(() => { + loadModels(); + }, 2000); + return () => clearInterval(interval); + }, [models, loadModels]); + const selectModel = async (modelId: string) => { try { await setConfigProvider({ @@ -365,6 +427,7 @@ export const LocalInferenceSettings = () => { {intl.formatMessage(i18n.recommended)} )} +