Add vision/image support for local inference models (#8442)

Signed-off-by: jh-block <jhugo@block.xyz>
jh-block 2026-04-13 10:17:04 +02:00 committed by GitHub
parent 5fa2a8b821
commit de317d5445
15 changed files with 1181 additions and 88 deletions

View file

@ -1604,6 +1604,9 @@ async fn handle_local_models_command(command: LocalModelsCommand) -> Result<()>
source_url: file.download_url.clone(),
settings: Default::default(),
size_bytes: file.size_bytes,
mmproj_path: None,
mmproj_source_url: None,
mmproj_size_bytes: 0,
};
{

View file

@ -1,3 +1,5 @@
use std::path::PathBuf;
use crate::routes::errors::ErrorResponse;
use crate::state::AppState;
use axum::{
@ -13,9 +15,9 @@ use goose::providers::local_inference::{
available_inference_memory_bytes,
hf_models::{resolve_model_spec, HfGgufFile},
local_model_registry::{
default_settings_for_model, get_registry, is_featured_model, model_id_from_repo,
LocalModelEntry, ModelDownloadStatus as RegistryDownloadStatus, ModelSettings,
FEATURED_MODELS,
default_settings_for_model, featured_mmproj_spec, get_registry, is_featured_model,
model_id_from_repo, LocalModelEntry, ModelDownloadStatus as RegistryDownloadStatus,
ModelSettings, FEATURED_MODELS,
},
recommend_local_model,
};
@ -47,10 +49,14 @@ pub struct LocalModelResponse {
pub status: ModelDownloadStatus,
pub recommended: bool,
pub settings: ModelSettings,
pub vision_capable: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub mmproj_status: Option<ModelDownloadStatus>,
}
async fn ensure_featured_models_in_registry() -> Result<(), ErrorResponse> {
let mut entries_to_add = Vec::new();
let mut mmproj_downloads_needed: Vec<(String, String, PathBuf)> = Vec::new();
for featured in FEATURED_MODELS {
let (repo_id, quantization) = match hf_models::parse_model_spec(featured.spec) {
@ -64,9 +70,28 @@ async fn ensure_featured_models_in_registry() -> Result<(), ErrorResponse> {
let registry = get_registry()
.lock()
.map_err(|_| ErrorResponse::internal("Failed to acquire registry lock"))?;
if registry.has_model(&model_id) {
if let Some(existing) = registry.get_model(&model_id) {
let needs_backfill = existing.mmproj_path.is_none() && featured.mmproj.is_some();
let needs_download = existing.is_downloaded()
&& featured.mmproj.is_some()
&& !existing.mmproj_path.as_ref().is_some_and(|p| p.exists());
if needs_download {
if let Some(mmproj) = featured.mmproj.as_ref() {
let path = mmproj.local_path();
let url = format!(
"https://huggingface.co/{}/resolve/main/{}",
mmproj.repo, mmproj.filename
);
mmproj_downloads_needed.push((model_id.clone(), url, path));
}
}
if !needs_backfill {
continue;
}
// Fall through to build the entry for sync_with_featured backfill
}
}
let hf_file = match resolve_model_spec(featured.spec).await {
@ -91,6 +116,8 @@ async fn ensure_featured_models_in_registry() -> Result<(), ErrorResponse> {
let local_path = Paths::in_data_dir("models").join(&hf_file.filename);
// enrich_with_featured_mmproj is called by sync_with_featured/add_model,
// so we don't need to populate mmproj fields here.
entries_to_add.push(LocalModelEntry {
id: model_id.clone(),
repo_id,
@ -100,16 +127,60 @@ async fn ensure_featured_models_in_registry() -> Result<(), ErrorResponse> {
source_url: hf_file.download_url,
settings: default_settings_for_model(&model_id),
size_bytes: hf_file.size_bytes,
mmproj_path: None,
mmproj_source_url: None,
mmproj_size_bytes: 0,
});
}
if !entries_to_add.is_empty() {
{
let mut registry = get_registry()
.lock()
.map_err(|_| ErrorResponse::internal("Failed to acquire registry lock"))?;
if !entries_to_add.is_empty() {
registry.sync_with_featured(entries_to_add);
}
// Backfill mmproj data for all registry models and collect any
// needed mmproj downloads for models already on disk.
for model in registry.list_models_mut() {
model.enrich_with_featured_mmproj();
if model.is_downloaded() {
if let Some(mmproj) = featured_mmproj_spec(&model.id) {
let path = mmproj.local_path();
if !path.exists() {
let url = format!(
"https://huggingface.co/{}/resolve/main/{}",
mmproj.repo, mmproj.filename
);
mmproj_downloads_needed.push((model.id.clone(), url, path));
}
}
}
}
let _ = registry.save();
}
// Auto-download mmproj files for models that are already downloaded.
// Deduplicate by path since multiple quants share one mmproj file.
let dm = get_download_manager();
let mut started_paths = std::collections::HashSet::new();
for (model_id, url, path) in mmproj_downloads_needed {
if !path.exists() && started_paths.insert(path.clone()) {
let download_id = format!("{}-mmproj", model_id);
let dominated_by_active = dm
.get_progress(&download_id)
.is_some_and(|p| p.status == goose::download_manager::DownloadStatus::Downloading);
if !dominated_by_active {
tracing::info!(model_id = %model_id, "Auto-downloading vision encoder for existing model");
if let Err(e) = dm.download_model(download_id, url, path, None).await {
tracing::warn!(model_id = %model_id, error = %e, "Failed to start mmproj download");
}
}
}
}
Ok(())
}
@ -154,6 +225,28 @@ pub async fn list_local_models(
let size_bytes = entry.file_size();
let vision_capable = entry.settings.vision_capable;
let mmproj_status = if vision_capable {
let ms = entry.mmproj_download_status();
Some(match ms {
RegistryDownloadStatus::NotDownloaded => ModelDownloadStatus::NotDownloaded,
RegistryDownloadStatus::Downloading {
progress_percent,
bytes_downloaded,
total_bytes,
speed_bps,
} => ModelDownloadStatus::Downloading {
progress_percent,
bytes_downloaded,
total_bytes,
speed_bps: Some(speed_bps),
},
RegistryDownloadStatus::Downloaded => ModelDownloadStatus::Downloaded,
})
} else {
None
};
models.push(LocalModelResponse {
id: entry.id.clone(),
repo_id: entry.repo_id.clone(),
@ -163,6 +256,8 @@ pub async fn list_local_models(
status,
recommended: recommended_id == entry.id,
settings: entry.settings.clone(),
vision_capable,
mmproj_status,
});
}
@ -276,16 +371,26 @@ pub async fn download_hf_model(
source_url: download_url.clone(),
settings: default_settings_for_model(&model_id),
size_bytes: hf_file.size_bytes,
mmproj_path: None,
mmproj_source_url: None,
mmproj_size_bytes: 0,
};
{
// add_model enriches the entry with mmproj metadata from the featured table
let mmproj_path = {
let mut registry = get_registry()
.lock()
.map_err(|_| ErrorResponse::internal("Failed to acquire registry lock"))?;
registry
.add_model(entry)
.map_err(|e| ErrorResponse::internal(format!("{}", e)))?;
}
registry.get_model(&model_id).and_then(|e| {
e.mmproj_path
.as_ref()
.zip(e.mmproj_source_url.as_ref())
.map(|(p, u)| (p.clone(), u.clone()))
})
};
let dm = get_download_manager();
dm.download_model(
@ -297,6 +402,19 @@ pub async fn download_hf_model(
.await
.map_err(|e| ErrorResponse::internal(format!("Download failed: {}", e)))?;
if let Some((mmproj_path, mmproj_url)) = mmproj_path {
if !mmproj_path.exists() {
dm.download_model(
format!("{}-mmproj", model_id),
mmproj_url,
mmproj_path,
None,
)
.await
.map_err(|e| ErrorResponse::internal(format!("mmproj download failed: {}", e)))?;
}
}
Ok((StatusCode::ACCEPTED, Json(model_id)))
}
@ -338,6 +456,7 @@ pub async fn cancel_local_model_download(
manager
.cancel_download(&format!("{}-model", model_id))
.map_err(|e| ErrorResponse::internal(format!("{}", e)))?;
let _ = manager.cancel_download(&format!("{}-mmproj", model_id));
Ok(StatusCode::OK)
}
@ -351,14 +470,22 @@ pub async fn cancel_local_model_download(
)
)]
pub async fn delete_local_model(Path(model_id): Path<String>) -> Result<StatusCode, ErrorResponse> {
let local_path = {
let (local_path, mmproj_path, other_uses_mmproj) = {
let registry = get_registry()
.lock()
.map_err(|_| ErrorResponse::internal("Failed to acquire registry lock"))?;
let entry = registry
.get_model(&model_id)
.ok_or_else(|| ErrorResponse::not_found("Model not found"))?;
entry.local_path.clone()
let lp = entry.local_path.clone();
let mp = entry.mmproj_path.clone();
// Check if another downloaded model shares this mmproj file
let shared = mp.as_ref().is_some_and(|target| {
registry.list_models().iter().any(|m| {
m.id != model_id && m.is_downloaded() && m.mmproj_path.as_ref() == Some(target)
})
});
(lp, mp, shared)
};
if local_path.exists() {
@ -367,6 +494,14 @@ pub async fn delete_local_model(Path(model_id): Path<String>) -> Result<StatusCo
.map_err(|e| ErrorResponse::internal(format!("Failed to delete: {}", e)))?;
}
if !other_uses_mmproj {
if let Some(mmproj) = mmproj_path {
if mmproj.exists() {
let _ = tokio::fs::remove_file(&mmproj).await;
}
}
}
// Only remove non-featured models from registry (featured ones stay as placeholders)
if !is_featured_model(&model_id) {
let mut registry = get_registry()

View file

@ -179,7 +179,7 @@ tree-sitter-typescript = { workspace = true }
which = { workspace = true }
pctx_code_mode = { version = "^0.3.0", optional = true }
pulldown-cmark = "0.13.0"
llama-cpp-2 = { version = "0.1.143", features = ["sampler"], optional = true }
llama-cpp-2 = { version = "0.1.143", features = ["sampler", "mtmd"], optional = true }
encoding_rs = "0.8.35"
pastey = "0.2.1"
shell-words = { workspace = true }
@ -197,7 +197,7 @@ keyring = { version = "3.6.2", features = ["windows-native"] }
[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { version = "0.9", default-features = false, features = ["metal"], optional = true }
candle-nn = { version = "0.9", default-features = false, features = ["metal"], optional = true }
llama-cpp-2 = { version = "0.1.143", features = ["sampler", "metal"], optional = true }
llama-cpp-2 = { version = "0.1.143", features = ["sampler", "metal", "mtmd"], optional = true }
keyring = { version = "3.6.2", features = ["apple-native"] }
[target.'cfg(target_os = "linux")'.dependencies]

View file

@ -3,6 +3,7 @@ mod inference_emulated_tools;
mod inference_engine;
mod inference_native_tools;
pub mod local_model_registry;
pub(crate) mod multimodal;
mod tool_parsing;
use inference_emulated_tools::{
@ -30,6 +31,7 @@ use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::{LlamaChatMessage, LlamaChatTemplate, LlamaModel};
use llama_cpp_2::{list_llama_ggml_backend_devices, LlamaBackendDeviceType, LogOptions};
use multimodal::ExtractedImage;
use rmcp::model::{Role, Tool};
use serde_json::{json, Value};
use std::collections::HashMap;
@ -114,14 +116,15 @@ const DEFAULT_MODEL: &str = "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M";
pub const LOCAL_LLM_MODEL_CONFIG_KEY: &str = "LOCAL_LLM_MODEL";
/// Resolve model path, context limit, and settings for a model ID from the registry.
pub fn resolve_model_path(
model_id: &str,
) -> Option<(
PathBuf,
usize,
crate::providers::local_inference::local_model_registry::ModelSettings,
)> {
pub struct ResolvedModelPaths {
pub model_path: PathBuf,
pub context_limit: usize,
pub settings: crate::providers::local_inference::local_model_registry::ModelSettings,
pub mmproj_path: Option<PathBuf>,
}
/// Resolve model path, context limit, settings, and mmproj path for a model ID from the registry.
pub fn resolve_model_path(model_id: &str) -> Option<ResolvedModelPaths> {
use crate::providers::local_inference::local_model_registry::{
default_settings_for_model, get_registry,
};
@ -135,7 +138,15 @@ pub fn resolve_model_path(
// recognized (or with a different quantization) still get the right behavior.
let defaults = default_settings_for_model(model_id);
settings.native_tool_calling = defaults.native_tool_calling;
return Some((entry.local_path.clone(), ctx, settings));
settings.vision_capable = defaults.vision_capable;
settings.mmproj_size_bytes = entry.mmproj_size_bytes;
let mmproj_path = entry.mmproj_path.as_ref().filter(|p| p.exists()).cloned();
return Some(ResolvedModelPaths {
model_path: entry.local_path.clone(),
context_limit: ctx,
settings,
mmproj_path,
});
}
}
@ -208,9 +219,33 @@ fn build_openai_messages_json(system: &str, messages: &[Message]) -> String {
let mut arr: Vec<Value> = vec![json!({"role": "system", "content": system})];
arr.extend(format_messages(messages, &ImageFormat::OpenAi));
strip_image_parts_from_messages(&mut arr);
serde_json::to_string(&arr).unwrap_or_else(|_| "[]".to_string())
}
/// Remove `image_url` content parts from OpenAI-format messages JSON, replacing
/// each with a text note. This prevents an FFI crash in llama.cpp which does not
/// accept `image_url` content-part types.
fn strip_image_parts_from_messages(messages: &mut [Value]) {
let mut stripped = false;
for msg in messages.iter_mut() {
if let Some(content) = msg.get_mut("content").and_then(|c| c.as_array_mut()) {
for part in content.iter_mut() {
if part.get("type").and_then(|t| t.as_str()) == Some("image_url") {
*part = json!({
"type": "text",
"text": "[Image attached — image input is not supported with the currently selected model]"
});
stripped = true;
}
}
}
}
if stripped {
tracing::warn!("Stripped image content parts from messages — vision encoder not available for this model");
}
}
/// Convert a message into plain text for the emulator path's chat history.
///
/// This is the emulator-path counterpart of [`format_messages`] used by the native
@ -269,6 +304,12 @@ fn extract_text_content(msg: &Message) -> String {
parts.push(format!("Command error: {}", e));
}
},
MessageContent::Image(_) => {
parts.push(
"[Image attached — image input is not supported with the currently selected model]"
.to_string(),
);
}
_ => {}
}
}
@ -331,8 +372,9 @@ impl LocalInferenceProvider {
model_id: &str,
settings: &crate::providers::local_inference::local_model_registry::ModelSettings,
) -> Result<LoadedModel, ProviderError> {
let (model_path, _context_limit, _) = resolve_model_path(model_id)
let resolved = resolve_model_path(model_id)
.ok_or_else(|| ProviderError::ExecutionError(format!("Unknown model: {}", model_id)))?;
let model_path = resolved.model_path;
if !model_path.exists() {
return Err(ProviderError::ExecutionError(format!(
@ -368,9 +410,49 @@ impl LocalInferenceProvider {
}
};
let mtmd_ctx = Self::init_mtmd_context(&model, &resolved.mmproj_path, settings);
tracing::info!(model_id = model_id, "Model loaded successfully");
Ok(LoadedModel { model, template })
Ok(LoadedModel {
model,
template,
mtmd_ctx,
})
}
fn init_mtmd_context(
model: &LlamaModel,
mmproj_path: &Option<PathBuf>,
settings: &crate::providers::local_inference::local_model_registry::ModelSettings,
) -> Option<llama_cpp_2::mtmd::MtmdContext> {
use llama_cpp_2::mtmd::{MtmdContext, MtmdContextParams};
let mmproj_path = mmproj_path.as_ref().filter(|p| p.exists())?;
let params = MtmdContextParams {
use_gpu: true,
n_threads: settings
.n_threads
.unwrap_or_else(|| MtmdContextParams::default().n_threads),
..MtmdContextParams::default()
};
match MtmdContext::init_from_file(mmproj_path.to_str().unwrap_or_default(), model, &params)
{
Ok(ctx) => {
tracing::info!(
vision = ctx.support_vision(),
audio = ctx.support_audio(),
"Multimodal context initialized"
);
Some(ctx)
}
Err(e) => {
tracing::warn!(error = %e, "Failed to init multimodal context");
None
}
}
}
}
@ -453,13 +535,11 @@ impl Provider for LocalInferenceProvider {
messages: &[Message],
tools: &[Tool],
) -> Result<MessageStream, ProviderError> {
let (_model_path, model_context_limit, model_settings) =
resolve_model_path(&model_config.model_name).ok_or_else(|| {
ProviderError::ExecutionError(format!(
"Model not found: {}",
model_config.model_name
))
let resolved = resolve_model_path(&model_config.model_name).ok_or_else(|| {
ProviderError::ExecutionError(format!("Model not found: {}", model_config.model_name))
})?;
let model_context_limit = resolved.context_limit;
let model_settings = resolved.settings;
// Ensure model is loaded — unload any other models first to free memory.
{
@ -503,6 +583,18 @@ impl Provider for LocalInferenceProvider {
system.to_string()
};
// Extract images for vision-capable models, replacing them with markers.
// For non-vision models, leave messages unchanged (existing strip logic handles them).
let has_vision = resolved.mmproj_path.is_some();
let marker = llama_cpp_2::mtmd::mtmd_default_marker();
let (images, vision_messages): (Vec<ExtractedImage>, Option<Vec<Message>>) = if has_vision {
let (imgs, msgs) = multimodal::extract_images_from_messages(messages, marker);
(imgs, Some(msgs))
} else {
(Vec::new(), None)
};
let effective_messages: &[Message] = vision_messages.as_deref().unwrap_or(messages);
// Build chat messages for the template
let mut chat_messages =
vec![
@ -529,7 +621,7 @@ impl Provider for LocalInferenceProvider {
})?];
}
for msg in messages {
for msg in effective_messages {
let role = match msg.role {
Role::User => "user",
Role::Assistant => "assistant",
@ -553,7 +645,10 @@ impl Provider for LocalInferenceProvider {
};
let oai_messages_json = if model_settings.use_jinja || native_tool_calling {
Some(build_openai_messages_json(&system_prompt, messages))
Some(build_openai_messages_json(
&system_prompt,
effective_messages,
))
} else {
None
};
@ -563,6 +658,7 @@ impl Provider for LocalInferenceProvider {
let model_name = model_config.model_name.clone();
let context_limit = model_context_limit;
let settings = model_settings;
let mmproj_path = resolved.mmproj_path.clone();
let log_payload = serde_json::json!({
"system": &system_prompt,
@ -604,8 +700,8 @@ impl Provider for LocalInferenceProvider {
}};
}
let model_guard = model_arc.blocking_lock();
let loaded = match model_guard.as_ref() {
let mut model_guard = model_arc.blocking_lock();
let loaded = match model_guard.as_mut() {
Some(l) => l,
None => {
send_err!(ProviderError::ExecutionError(
@ -614,6 +710,16 @@ impl Provider for LocalInferenceProvider {
}
};
// Lazily initialize the multimodal context if the vision encoder
// was downloaded after the model was loaded.
if !images.is_empty() && loaded.mtmd_ctx.is_none() {
loaded.mtmd_ctx = LocalInferenceProvider::init_mtmd_context(
&loaded.model,
&mmproj_path,
&settings,
);
}
let message_id = Uuid::new_v4().to_string();
let mut gen_ctx = GenerationContext {
@ -626,6 +732,7 @@ impl Provider for LocalInferenceProvider {
message_id: &message_id,
tx: &tx,
log: &mut log,
images: &images,
};
let result = if use_emulator {

View file

@ -29,8 +29,8 @@ use std::borrow::Cow;
use uuid::Uuid;
use super::inference_engine::{
create_and_prefill_context, generation_loop, validate_and_compute_context, GenerationContext,
TokenAction,
create_and_prefill_context, create_and_prefill_multimodal, generation_loop,
validate_and_compute_context, GenerationContext, TokenAction,
};
use super::{finalize_usage, StreamSender, CODE_EXECUTION_TOOL, SHELL_TOOL};
@ -370,26 +370,32 @@ pub(super) fn generate_with_emulated_tools(
ProviderError::ExecutionError(format!("Failed to apply chat template: {}", e))
})?;
let (mut llama_ctx, prompt_token_count, effective_ctx) = if !ctx.images.is_empty() {
create_and_prefill_multimodal(
ctx.loaded,
ctx.runtime,
&prompt,
ctx.images,
ctx.context_limit,
ctx.settings,
)?
} else {
let tokens = ctx
.loaded
.model
.str_to_token(&prompt, AddBos::Never)
.map_err(|e| ProviderError::ExecutionError(e.to_string()))?;
let (prompt_token_count, effective_ctx) = validate_and_compute_context(
let (ptc, ectx) = validate_and_compute_context(
ctx.loaded,
ctx.runtime,
tokens.len(),
ctx.context_limit,
ctx.settings,
)?;
let mut llama_ctx = create_and_prefill_context(
ctx.loaded,
ctx.runtime,
&tokens,
effective_ctx,
ctx.settings,
)?;
let lctx =
create_and_prefill_context(ctx.loaded, ctx.runtime, &tokens, ectx, ctx.settings)?;
(lctx, ptc, ectx)
};
let message_id = ctx.message_id;
let tx = ctx.tx;

View file

@ -1,9 +1,11 @@
use crate::providers::errors::ProviderError;
use crate::providers::local_inference::local_model_registry::ModelSettings;
use crate::providers::local_inference::multimodal::ExtractedImage;
use crate::providers::utils::RequestLog;
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::{LlamaChatMessage, LlamaChatTemplate, LlamaModel};
use llama_cpp_2::mtmd::{MtmdBitmap, MtmdContext, MtmdInputText};
use llama_cpp_2::sampling::LlamaSampler;
use std::num::NonZeroU32;
@ -19,11 +21,14 @@ pub(super) struct GenerationContext<'a> {
pub message_id: &'a str,
pub tx: &'a StreamSender,
pub log: &'a mut RequestLog,
pub images: &'a [ExtractedImage],
}
pub(super) struct LoadedModel {
pub model: LlamaModel,
pub template: LlamaChatTemplate,
/// Multimodal context for vision models. None for text-only models.
pub mtmd_ctx: Option<MtmdContext>,
}
/// Estimate the maximum context length that can fit in available accelerator/CPU
@ -33,11 +38,13 @@ pub(super) struct LoadedModel {
pub(super) fn estimate_max_context_for_memory(
model: &LlamaModel,
runtime: &InferenceRuntime,
mmproj_overhead_bytes: u64,
) -> Option<usize> {
let available = super::available_inference_memory_bytes(runtime);
if available == 0 {
let raw_available = super::available_inference_memory_bytes(runtime);
if raw_available == 0 {
return None;
}
let available = raw_available.saturating_sub(mmproj_overhead_bytes);
// Reserve memory for computation scratch buffers (attention, etc.) and other overhead.
// The compute buffer can be 40-50% of the KV cache size for large models, so we
@ -209,7 +216,12 @@ pub(super) fn validate_and_compute_context(
settings: &crate::providers::local_inference::local_model_registry::ModelSettings,
) -> Result<(usize, usize), ProviderError> {
let n_ctx_train = loaded.model.n_ctx_train() as usize;
let memory_max_ctx = estimate_max_context_for_memory(&loaded.model, runtime);
let mmproj_overhead = if loaded.mtmd_ctx.is_some() {
settings.mmproj_size_bytes
} else {
0
};
let memory_max_ctx = estimate_max_context_for_memory(&loaded.model, runtime, mmproj_overhead);
let effective_ctx = effective_context_size(
prompt_token_count,
settings,
@ -261,6 +273,80 @@ pub(super) fn create_and_prefill_context<'model>(
Ok(ctx)
}
/// Tokenize text + images via mtmd and prefill the context.
///
/// Returns the llama context, the number of prompt tokens consumed,
/// and the effective context size.
pub(super) fn create_and_prefill_multimodal<'model>(
loaded: &'model LoadedModel,
runtime: &InferenceRuntime,
prompt_text: &str,
images: &[ExtractedImage],
context_limit: usize,
settings: &ModelSettings,
) -> Result<(llama_cpp_2::context::LlamaContext<'model>, usize, usize), ProviderError> {
let mtmd_ctx = loaded.mtmd_ctx.as_ref().ok_or_else(|| {
ProviderError::ExecutionError(
"This model does not have vision support. Download the vision encoder from \
Settings > Local Inference, or use a text-only message."
.to_string(),
)
})?;
let bitmaps: Vec<MtmdBitmap> = images
.iter()
.map(|img| {
MtmdBitmap::from_buffer(mtmd_ctx, &img.bytes)
.map_err(|e| ProviderError::ExecutionError(format!("Failed to decode image: {e}")))
})
.collect::<Result<_, _>>()?;
let bitmap_refs: Vec<&MtmdBitmap> = bitmaps.iter().collect();
let input_text = MtmdInputText {
text: prompt_text.to_string(),
add_special: true,
parse_special: true,
};
let chunks = mtmd_ctx.tokenize(input_text, &bitmap_refs).map_err(|e| {
ProviderError::ExecutionError(format!("Multimodal tokenization failed: {e}"))
})?;
let prompt_token_count = chunks.total_tokens();
let n_ctx_train = loaded.model.n_ctx_train() as usize;
let mmproj_overhead = settings.mmproj_size_bytes;
let memory_max_ctx = estimate_max_context_for_memory(&loaded.model, runtime, mmproj_overhead);
let effective_ctx = effective_context_size(
prompt_token_count,
settings,
context_limit,
n_ctx_train,
memory_max_ctx,
);
let min_generation_headroom = 512;
if prompt_token_count + min_generation_headroom > effective_ctx {
return Err(ProviderError::ContextLengthExceeded(format!(
"Multimodal prompt ({prompt_token_count} tokens including images) exceeds \
context limit ({effective_ctx} tokens)",
)));
}
let ctx_params = build_context_params(effective_ctx as u32, settings);
let llama_ctx = loaded
.model
.new_context(runtime.backend(), ctx_params)
.map_err(|e| ProviderError::ExecutionError(format!("Failed to create context: {e}")))?;
let n_batch = llama_ctx.n_batch() as i32;
let _n_past = chunks
.eval_chunks(mtmd_ctx, &llama_ctx, 0, 0, n_batch, true)
.map_err(|e| ProviderError::ExecutionError(format!("Multimodal eval failed: {e}")))?;
Ok((llama_ctx, prompt_token_count, effective_ctx))
}
/// Action to take after processing a generated token piece.
pub(super) enum TokenAction {
Continue,

View file

@ -9,8 +9,9 @@ use uuid::Uuid;
use super::finalize_usage;
use super::inference_engine::{
context_cap, create_and_prefill_context, estimate_max_context_for_memory, generation_loop,
validate_and_compute_context, GenerationContext, TokenAction,
context_cap, create_and_prefill_context, create_and_prefill_multimodal,
estimate_max_context_for_memory, generation_loop, validate_and_compute_context,
GenerationContext, TokenAction,
};
pub(super) fn generate_with_native_tools(
@ -21,7 +22,13 @@ pub(super) fn generate_with_native_tools(
) -> Result<(), ProviderError> {
let min_generation_headroom = 512;
let n_ctx_train = ctx.loaded.model.n_ctx_train() as usize;
let memory_max_ctx = estimate_max_context_for_memory(&ctx.loaded.model, ctx.runtime);
let mmproj_overhead = if ctx.loaded.mtmd_ctx.is_some() {
ctx.settings.mmproj_size_bytes
} else {
0
};
let memory_max_ctx =
estimate_max_context_for_memory(&ctx.loaded.model, ctx.runtime, mmproj_overhead);
let cap = context_cap(ctx.settings, ctx.context_limit, n_ctx_train, memory_max_ctx);
let token_budget = cap.saturating_sub(min_generation_headroom);
@ -61,6 +68,8 @@ pub(super) fn generate_with_native_tools(
}
};
let estimated_image_tokens = ctx.images.len() * ctx.settings.image_token_estimate;
let template_result = match apply_template(full_tools_json) {
Ok(r) => {
let token_count = ctx
@ -69,7 +78,7 @@ pub(super) fn generate_with_native_tools(
.str_to_token(&r.prompt, AddBos::Never)
.map(|t| t.len())
.unwrap_or(0);
if token_count > token_budget {
if token_count + estimated_image_tokens > token_budget {
apply_template(compact_tools).unwrap_or(r)
} else {
r
@ -85,26 +94,32 @@ pub(super) fn generate_with_native_tools(
None,
);
let (mut llama_ctx, prompt_token_count, effective_ctx) = if !ctx.images.is_empty() {
create_and_prefill_multimodal(
ctx.loaded,
ctx.runtime,
&template_result.prompt,
ctx.images,
ctx.context_limit,
ctx.settings,
)?
} else {
let tokens = ctx
.loaded
.model
.str_to_token(&template_result.prompt, AddBos::Never)
.map_err(|e| ProviderError::ExecutionError(e.to_string()))?;
let (prompt_token_count, effective_ctx) = validate_and_compute_context(
let (ptc, ectx) = validate_and_compute_context(
ctx.loaded,
ctx.runtime,
tokens.len(),
ctx.context_limit,
ctx.settings,
)?;
let mut llama_ctx = create_and_prefill_context(
ctx.loaded,
ctx.runtime,
&tokens,
effective_ctx,
ctx.settings,
)?;
let lctx =
create_and_prefill_context(ctx.loaded, ctx.runtime, &tokens, ectx, ctx.settings)?;
(lctx, ptc, ectx)
};
let message_id = ctx.message_id;
let tx = ctx.tx;

View file

@ -62,12 +62,27 @@ pub struct ModelSettings {
pub use_jinja: bool,
#[serde(default = "default_true")]
pub enable_thinking: bool,
/// Whether this model architecture supports vision input.
/// Derived from the featured model table, not user-configurable.
#[serde(default)]
pub vision_capable: bool,
/// Estimated tokens per image for budget planning before mtmd tokenization.
/// The actual count is determined after tokenization via `chunks.total_tokens()`.
#[serde(default = "default_image_token_estimate")]
pub image_token_estimate: usize,
/// Size of the mmproj file in bytes, used for memory accounting.
#[serde(default)]
pub mmproj_size_bytes: u64,
}
fn default_true() -> bool {
true
}
fn default_image_token_estimate() -> usize {
256
}
fn default_repeat_penalty() -> f32 {
1.0
}
@ -94,41 +109,75 @@ impl Default for ModelSettings {
native_tool_calling: false,
use_jinja: false,
enable_thinking: true,
vision_capable: false,
image_token_estimate: default_image_token_estimate(),
mmproj_size_bytes: 0,
}
}
}
/// HuggingFace repo + filename for multimodal projection weights (vision encoder).
pub struct MmprojSpec {
pub repo: &'static str,
pub filename: &'static str,
}
impl MmprojSpec {
/// Local path for this mmproj, namespaced by repo to avoid collisions
/// between different models that use the same filename.
pub fn local_path(&self) -> std::path::PathBuf {
let repo_name = self.repo.split('/').next_back().unwrap_or(self.repo);
Paths::in_data_dir("models")
.join(repo_name)
.join(self.filename)
}
}
pub struct FeaturedModel {
/// HuggingFace spec in "author/repo-GGUF:quantization" format.
pub spec: &'static str,
/// Whether this model's GGUF template supports native tool calling via llama.cpp.
pub native_tool_calling: bool,
/// Multimodal projection weights spec. None for text-only models.
pub mmproj: Option<MmprojSpec>,
}
pub const FEATURED_MODELS: &[FeaturedModel] = &[
FeaturedModel {
spec: "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M",
native_tool_calling: false,
mmproj: None,
},
FeaturedModel {
spec: "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M",
native_tool_calling: false,
mmproj: None,
},
FeaturedModel {
spec: "bartowski/Hermes-2-Pro-Mistral-7B-GGUF:Q4_K_M",
native_tool_calling: false,
mmproj: None,
},
FeaturedModel {
spec: "bartowski/Mistral-Small-24B-Instruct-2501-GGUF:Q4_K_M",
native_tool_calling: false,
mmproj: None,
},
FeaturedModel {
spec: "unsloth/gemma-4-E4B-it-GGUF:Q4_K_M",
native_tool_calling: true,
mmproj: Some(MmprojSpec {
repo: "unsloth/gemma-4-E4B-it-GGUF",
filename: "mmproj-BF16.gguf",
}),
},
FeaturedModel {
spec: "unsloth/gemma-4-26B-A4B-it-GGUF:Q4_K_M",
native_tool_calling: true,
mmproj: Some(MmprojSpec {
repo: "unsloth/gemma-4-26B-A4B-it-GGUF",
filename: "mmproj-BF16.gguf",
}),
},
];
@ -144,10 +193,25 @@ pub fn default_settings_for_model(model_id: &str) -> ModelSettings {
});
ModelSettings {
native_tool_calling: featured.is_some_and(|m| m.native_tool_calling),
vision_capable: featured.is_some_and(|m| m.mmproj.is_some()),
..ModelSettings::default()
}
}
/// Look up the `MmprojSpec` for a featured model by its model ID.
pub fn featured_mmproj_spec(model_id: &str) -> Option<&'static MmprojSpec> {
use super::hf_models::parse_model_spec;
let model_repo = model_id.split(':').next().unwrap_or(model_id);
FEATURED_MODELS.iter().find_map(|m| {
if let Ok((repo_id, _quant)) = parse_model_spec(m.spec) {
if repo_id == model_repo {
return m.mmproj.as_ref();
}
}
None
})
}
/// Check if a model ID corresponds to a featured model.
pub fn is_featured_model(model_id: &str) -> bool {
use super::hf_models::parse_model_spec;
@ -181,9 +245,42 @@ pub struct LocalModelEntry {
pub settings: ModelSettings,
#[serde(default)]
pub size_bytes: u64,
/// Local path to the multimodal projection GGUF (vision encoder).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub mmproj_path: Option<PathBuf>,
/// Download URL for the mmproj file.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub mmproj_source_url: Option<String>,
/// Size of the mmproj file in bytes.
#[serde(default)]
pub mmproj_size_bytes: u64,
}
impl LocalModelEntry {
/// Populate mmproj metadata and vision settings from the featured model
/// table if this model's repo has a known vision encoder.
pub fn enrich_with_featured_mmproj(&mut self) {
if let Some(mmproj) = featured_mmproj_spec(&self.id) {
let path = mmproj.local_path();
if self.mmproj_path.as_ref() != Some(&path) {
self.mmproj_path = Some(path.clone());
self.mmproj_source_url = Some(format!(
"https://huggingface.co/{}/resolve/main/{}",
mmproj.repo, mmproj.filename
));
}
self.settings.vision_capable = true;
if self.mmproj_size_bytes == 0 || self.settings.mmproj_size_bytes == 0 {
if let Ok(meta) = std::fs::metadata(&path) {
self.mmproj_size_bytes = meta.len();
self.settings.mmproj_size_bytes = meta.len();
}
}
}
let defaults = default_settings_for_model(&self.id);
self.settings.native_tool_calling = defaults.native_tool_calling;
}
pub fn is_downloaded(&self) -> bool {
self.local_path.exists()
}
@ -219,6 +316,36 @@ impl LocalModelEntry {
ModelDownloadStatus::NotDownloaded
}
pub fn has_vision(&self) -> bool {
self.mmproj_path.as_ref().is_some_and(|p| p.exists())
}
pub fn mmproj_download_status(&self) -> ModelDownloadStatus {
if let Some(path) = &self.mmproj_path {
if path.exists() {
return ModelDownloadStatus::Downloaded;
}
} else {
return ModelDownloadStatus::NotDownloaded;
}
let download_id = format!("{}-mmproj", self.id);
let manager = get_download_manager();
if let Some(progress) = manager.get_progress(&download_id) {
return match progress.status {
DownloadStatus::Downloading => ModelDownloadStatus::Downloading {
progress_percent: progress.progress_percent,
bytes_downloaded: progress.bytes_downloaded,
total_bytes: progress.total_bytes,
speed_bps: progress.speed_bps.unwrap_or(0),
},
_ => ModelDownloadStatus::NotDownloaded,
};
}
ModelDownloadStatus::NotDownloaded
}
pub fn file_size(&self) -> u64 {
if self.size_bytes > 0 {
return self.size_bytes;
@ -290,8 +417,9 @@ impl LocalModelRegistry {
pub fn sync_with_featured(&mut self, featured_entries: Vec<LocalModelEntry>) {
let mut changed = false;
for entry in featured_entries {
for mut entry in featured_entries {
if !self.models.iter().any(|m| m.id == entry.id) {
entry.enrich_with_featured_mmproj();
self.models.push(entry);
changed = true;
}
@ -309,7 +437,8 @@ impl LocalModelRegistry {
}
}
pub fn add_model(&mut self, entry: LocalModelEntry) -> Result<()> {
pub fn add_model(&mut self, mut entry: LocalModelEntry) -> Result<()> {
entry.enrich_with_featured_mmproj();
if let Some(existing) = self.models.iter_mut().find(|m| m.id == entry.id) {
*existing = entry;
} else {
@ -348,6 +477,10 @@ impl LocalModelRegistry {
pub fn list_models(&self) -> &[LocalModelEntry] {
&self.models
}
pub fn list_models_mut(&mut self) -> &mut [LocalModelEntry] {
&mut self.models
}
}
/// Generate a unique ID for a model from its repo_id and quantization.

View file

@ -0,0 +1,336 @@
use base64::prelude::*;
use serde_json::Value;
use crate::conversation::message::{Message, MessageContent};
use crate::providers::errors::ProviderError;
#[derive(Debug)]
pub struct ExtractedImage {
pub bytes: Vec<u8>,
}
#[derive(Debug)]
#[allow(dead_code)]
pub struct MultimodalMessages {
pub messages_json: String,
pub images: Vec<ExtractedImage>,
}
/// Walk the OpenAI-format messages JSON array. For each content part with
/// `type: "image_url"`, decode the base64 data URL, store the raw bytes,
/// and replace the part with `{"type": "text", "text": "<marker>"}`.
///
/// Returns the modified JSON string and the extracted images in order.
#[allow(dead_code)]
pub fn extract_images_from_messages_json(
messages_json: &str,
marker: &str,
) -> Result<MultimodalMessages, ProviderError> {
let mut messages: Vec<Value> = serde_json::from_str(messages_json).map_err(|e| {
ProviderError::ExecutionError(format!("Failed to parse messages JSON: {e}"))
})?;
let mut images = Vec::new();
for msg in messages.iter_mut() {
let Some(content) = msg.get_mut("content").and_then(|c| c.as_array_mut()) else {
continue;
};
for part in content.iter_mut() {
if part.get("type").and_then(|t| t.as_str()) != Some("image_url") {
continue;
}
let url = part
.get("image_url")
.and_then(|obj| obj.get("url"))
.and_then(|u| u.as_str())
.unwrap_or_default();
if url.starts_with("http://") || url.starts_with("https://") {
return Err(ProviderError::ExecutionError(
"Remote image URLs are not supported with local inference. \
Please attach the image directly."
.to_string(),
));
}
let base64_data = url.split_once(',').map_or(url, |(_, data)| data);
let bytes = BASE64_STANDARD.decode(base64_data).map_err(|e| {
ProviderError::ExecutionError(format!("Failed to decode base64 image: {e}"))
})?;
images.push(ExtractedImage { bytes });
*part = serde_json::json!({
"type": "text",
"text": marker,
});
}
}
let messages_json = serde_json::to_string(&messages)
.map_err(|e| ProviderError::ExecutionError(format!("Failed to serialize messages: {e}")))?;
Ok(MultimodalMessages {
messages_json,
images,
})
}
/// Scan messages for `MessageContent::Image` entries. Return the extracted image
/// bytes and a new message list with images replaced by text marker placeholders.
pub fn extract_images_from_messages(
messages: &[Message],
marker: &str,
) -> (Vec<ExtractedImage>, Vec<Message>) {
let mut images = Vec::new();
let mut new_messages = Vec::with_capacity(messages.len());
for msg in messages {
let mut new_content = Vec::with_capacity(msg.content.len());
for content in &msg.content {
match content {
MessageContent::Image(img) => {
if let Ok(bytes) = BASE64_STANDARD.decode(&img.data) {
images.push(ExtractedImage { bytes });
new_content.push(MessageContent::text(marker));
} else {
new_content.push(MessageContent::text(
"[Image attached — failed to decode image data]",
));
}
}
other => new_content.push(other.clone()),
}
}
new_messages.push(Message {
role: msg.role.clone(),
content: new_content,
..msg.clone()
});
}
(images, new_messages)
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
fn make_test_messages_json(parts: Vec<Value>) -> String {
serde_json::to_string(&vec![json!({
"role": "user",
"content": parts,
})])
.unwrap()
}
fn tiny_png_base64() -> String {
// 1x1 red PNG
let bytes: &[u8] = &[
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D, 0x49, 0x48,
0x44, 0x52, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02, 0x00, 0x00,
0x00, 0x90, 0x77, 0x53, 0xDE, 0x00, 0x00, 0x00, 0x0C, 0x49, 0x44, 0x41, 0x54, 0x08,
0xD7, 0x63, 0xF8, 0xCF, 0xC0, 0x00, 0x00, 0x00, 0x03, 0x00, 0x01, 0x36, 0x28, 0x19,
0x00, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4E, 0x44, 0xAE, 0x42, 0x60, 0x82,
];
BASE64_STANDARD.encode(bytes)
}
#[test]
fn test_extract_images_replaces_image_url_with_marker() {
let b64 = tiny_png_base64();
let json = make_test_messages_json(vec![json!({
"type": "image_url",
"image_url": {"url": format!("data:image/png;base64,{b64}")}
})]);
let result = extract_images_from_messages_json(&json, "<__media__>").unwrap();
assert_eq!(result.images.len(), 1);
assert!(!result.images[0].bytes.is_empty());
let parsed: Vec<Value> = serde_json::from_str(&result.messages_json).unwrap();
let content = parsed[0]["content"].as_array().unwrap();
assert_eq!(content[0]["type"], "text");
assert_eq!(content[0]["text"], "<__media__>");
}
#[test]
fn test_extract_images_preserves_text_parts() {
let json = make_test_messages_json(vec![json!({
"type": "text",
"text": "Hello world"
})]);
let result = extract_images_from_messages_json(&json, "<__media__>").unwrap();
assert!(result.images.is_empty());
let parsed: Vec<Value> = serde_json::from_str(&result.messages_json).unwrap();
let content = parsed[0]["content"].as_array().unwrap();
assert_eq!(content[0]["type"], "text");
assert_eq!(content[0]["text"], "Hello world");
}
#[test]
fn test_extract_images_multiple_images() {
let b64 = tiny_png_base64();
let json = make_test_messages_json(vec![
json!({"type": "image_url", "image_url": {"url": format!("data:image/png;base64,{b64}")}}),
json!({"type": "text", "text": "describe both"}),
json!({"type": "image_url", "image_url": {"url": format!("data:image/png;base64,{b64}")}}),
]);
let result = extract_images_from_messages_json(&json, "<__media__>").unwrap();
assert_eq!(result.images.len(), 2);
let parsed: Vec<Value> = serde_json::from_str(&result.messages_json).unwrap();
let content = parsed[0]["content"].as_array().unwrap();
assert_eq!(content[0]["type"], "text");
assert_eq!(content[0]["text"], "<__media__>");
assert_eq!(content[1]["type"], "text");
assert_eq!(content[1]["text"], "describe both");
assert_eq!(content[2]["type"], "text");
assert_eq!(content[2]["text"], "<__media__>");
}
#[test]
fn test_extract_images_no_images() {
let json = make_test_messages_json(vec![json!({
"type": "text",
"text": "just text"
})]);
let result = extract_images_from_messages_json(&json, "<__media__>").unwrap();
assert!(result.images.is_empty());
// JSON should be equivalent
let original: Vec<Value> = serde_json::from_str(&json).unwrap();
let result_parsed: Vec<Value> = serde_json::from_str(&result.messages_json).unwrap();
assert_eq!(original, result_parsed);
}
#[test]
fn test_extract_images_http_url_rejected() {
let json = make_test_messages_json(vec![json!({
"type": "image_url",
"image_url": {"url": "https://example.com/image.png"}
})]);
let result = extract_images_from_messages_json(&json, "<__media__>");
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("Remote image URLs are not supported"));
}
#[test]
fn test_extract_images_mixed_content() {
let b64 = tiny_png_base64();
// Two messages: first with text+image, second with just text
let json = serde_json::to_string(&vec![
json!({
"role": "user",
"content": [
{"type": "text", "text": "What is this?"},
{"type": "image_url", "image_url": {"url": format!("data:image/png;base64,{b64}")}},
]
}),
json!({
"role": "assistant",
"content": [
{"type": "text", "text": "It looks like a red pixel."},
]
}),
])
.unwrap();
let result = extract_images_from_messages_json(&json, "<__media__>").unwrap();
assert_eq!(result.images.len(), 1);
let parsed: Vec<Value> = serde_json::from_str(&result.messages_json).unwrap();
// First message: text preserved, image replaced
let content0 = parsed[0]["content"].as_array().unwrap();
assert_eq!(content0[0]["text"], "What is this?");
assert_eq!(content0[1]["text"], "<__media__>");
// Second message unchanged
let content1 = parsed[1]["content"].as_array().unwrap();
assert_eq!(content1[0]["text"], "It looks like a red pixel.");
}
// --- Tests for extract_images_from_messages (Message-based) ---
#[test]
fn test_messages_extract_replaces_image_with_marker() {
let b64 = tiny_png_base64();
let messages = vec![Message::user().with_image(b64, "image/png")];
let (images, new_msgs) = extract_images_from_messages(&messages, "<__media__>");
assert_eq!(images.len(), 1);
assert!(!images[0].bytes.is_empty());
assert_eq!(new_msgs.len(), 1);
assert_eq!(new_msgs[0].as_concat_text(), "<__media__>");
}
#[test]
fn test_messages_extract_preserves_text() {
let messages = vec![Message::user().with_text("Hello world")];
let (images, new_msgs) = extract_images_from_messages(&messages, "<__media__>");
assert!(images.is_empty());
assert_eq!(new_msgs[0].as_concat_text(), "Hello world");
}
#[test]
fn test_messages_extract_multiple_images() {
let b64 = tiny_png_base64();
let messages = vec![Message::user()
.with_image(b64.clone(), "image/png")
.with_text("describe both")
.with_image(b64, "image/png")];
let (images, new_msgs) = extract_images_from_messages(&messages, "<__media__>");
assert_eq!(images.len(), 2);
assert_eq!(new_msgs[0].content.len(), 3);
assert_eq!(
new_msgs[0].as_concat_text(),
"<__media__>\ndescribe both\n<__media__>"
);
}
#[test]
fn test_messages_extract_no_images() {
let messages = vec![Message::user().with_text("just text")];
let (images, new_msgs) = extract_images_from_messages(&messages, "<__media__>");
assert!(images.is_empty());
assert_eq!(new_msgs[0].as_concat_text(), "just text");
}
#[test]
fn test_messages_extract_invalid_base64() {
let messages = vec![Message::user().with_image("not-valid-base64!!!", "image/png")];
let (images, new_msgs) = extract_images_from_messages(&messages, "<__media__>");
assert!(images.is_empty());
assert!(new_msgs[0].as_concat_text().contains("failed to decode"));
}
#[test]
fn test_messages_extract_mixed_content() {
let b64 = tiny_png_base64();
let messages = vec![
Message::user()
.with_text("What is this?")
.with_image(b64, "image/png"),
Message::assistant().with_text("It looks like a red pixel."),
];
let (images, new_msgs) = extract_images_from_messages(&messages, "<__media__>");
assert_eq!(images.len(), 1);
assert_eq!(new_msgs.len(), 2);
assert_eq!(new_msgs[0].as_concat_text(), "What is this?\n<__media__>");
assert_eq!(new_msgs[1].as_concat_text(), "It looks like a red pixel.");
}
}

View file

@ -9,7 +9,11 @@
//!
//! Run with a specific model:
//! TEST_MODEL="bartowski/Qwen_Qwen3-32B-GGUF:Q4_K_M" cargo test -p goose --test local_inference_integration -- --ignored
//!
//! Run vision tests (requires a vision-capable model like gemma-4):
//! TEST_VISION_MODEL="unsloth/gemma-4-E4B-it-GGUF:Q4_K_M" cargo test -p goose --test local_inference_integration test_local_inference_vision -- --ignored
use base64::prelude::*;
use futures::StreamExt;
use goose::conversation::message::Message;
use goose::model::ModelConfig;
@ -93,3 +97,130 @@ async fn test_local_inference_large_prompt() {
text.len()
);
}
fn vision_test_model() -> Option<String> {
std::env::var("TEST_VISION_MODEL").ok()
}
/// Generate a small solid-colour 1x1 red PNG as raw bytes.
fn tiny_red_png() -> Vec<u8> {
vec![
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature
0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52, // IHDR chunk
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, // 1x1
0x08, 0x02, 0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE, // RGB, 8-bit
0x00, 0x00, 0x00, 0x0C, 0x49, 0x44, 0x41, 0x54, // IDAT chunk
0x08, 0xD7, 0x63, 0xF8, 0xCF, 0xC0, 0x00, 0x00, 0x00, 0x03, 0x00, 0x01, 0x36, 0x28, 0x19,
0x00, // compressed pixel data
0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4E, 0x44, 0xAE, 0x42, 0x60, 0x82, // IEND
]
}
/// Test that a vision-capable local model can process a message with an embedded image
/// and produce a text response without crashing.
///
/// Requires TEST_VISION_MODEL to be set to a downloaded vision model.
/// Example:
/// TEST_VISION_MODEL="unsloth/gemma-4-E4B-it-GGUF:Q4_K_M" \
/// cargo test -p goose --test local_inference_integration test_local_inference_vision -- --ignored
#[tokio::test]
#[ignore]
async fn test_local_inference_vision_produces_output() {
let model_id = match vision_test_model() {
Some(id) => id,
None => {
eprintln!(
"Skipping vision test: TEST_VISION_MODEL not set. \
Set it to a vision-capable model like unsloth/gemma-4-E4B-it-GGUF:Q4_K_M"
);
return;
}
};
let model_config = ModelConfig::new(&model_id).expect("valid model config");
let provider = create("local", model_config.clone(), Vec::new())
.await
.expect("provider creation should succeed");
let image_bytes = tiny_red_png();
let image_b64 = BASE64_STANDARD.encode(&image_bytes);
let system = "You are a helpful assistant. Describe images briefly.";
let messages = vec![Message::user()
.with_text("What color is this image?")
.with_image(image_b64, "image/png")];
let mut stream = provider
.stream(&model_config, "test-vision-session", system, &messages, &[])
.await
.expect("stream should start for vision input");
let mut got_text = false;
let mut collected_text = String::new();
while let Some(result) = stream.next().await {
let (msg, _usage) = result.expect("stream item should be Ok");
if let Some(m) = msg {
got_text = true;
collected_text.push_str(&m.as_concat_text());
}
}
assert!(
got_text,
"vision stream should produce at least one text message"
);
assert!(
!collected_text.is_empty(),
"vision response should contain text"
);
println!("Vision response: {collected_text}");
}
/// Test that sending an image to a text-only model produces a clear error
/// rather than crashing.
#[tokio::test]
#[ignore]
async fn test_local_inference_vision_text_only_model_graceful() {
let model_config = ModelConfig::new(&test_model()).expect("valid model config");
let provider = create("local", model_config.clone(), Vec::new())
.await
.expect("provider creation should succeed");
let image_bytes = tiny_red_png();
let image_b64 = BASE64_STANDARD.encode(&image_bytes);
let system = "You are a helpful assistant.";
let messages = vec![Message::user()
.with_text("What is this?")
.with_image(image_b64, "image/png")];
let mut stream = provider
.stream(&model_config, "test-session", system, &messages, &[])
.await
.expect("stream should start");
// The stream should either produce a response with the image stripped
// (placeholder text) or produce an error — but it must not crash.
let mut completed = false;
while let Some(result) = stream.next().await {
match result {
Ok(_) => completed = true,
Err(e) => {
// An error about missing vision support is acceptable
let err_msg = e.to_string();
assert!(
err_msg.contains("vision") || err_msg.contains("image"),
"error should mention vision/image support, got: {err_msg}"
);
completed = true;
break;
}
}
}
assert!(
completed,
"stream should complete without crashing when images sent to text-only model"
);
}

View file

@ -317,6 +317,31 @@ prompt: |
Log results to: {{ workspace_dir }}/phase3_delegation.md
{% endif %}
{% if test_phases == "all" or "vision" in test_phases %}
## 📷 PHASE 3B: Local Inference Vision Testing
**Prerequisites**: A vision-capable local model must be downloaded (e.g., gemma-4-E4B).
Skip this phase if no local vision model is available.
### Vision Smoke Test
1. Create a small test image:
```
python3 -c "import struct, zlib; raw=b'\x00\xff\x00\x00'; d=zlib.compress(raw); ihdr=b'\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00'; print('Created test.png')"
```
Or simply create a 1-pixel PNG test image using available tools.
2. Verify the test image file exists and is valid.
3. Send a message to the local vision model referencing the test image.
4. Verify the model responds with text (not an error or crash).
5. Verify the response acknowledges the image content.
### Vision Error Handling Test
1. If a text-only local model is available, send it a message with an image attached.
2. Verify it responds gracefully (either with a placeholder message or a clear error),
not with a crash or FFI error.
Log results to: {{ workspace_dir }}/phase3b_vision.md
{% endif %}
{% if test_phases == "all" or "advanced" in test_phases %}
## 🔬 PHASE 4: Advanced Testing

View file

@ -5724,7 +5724,8 @@
"size_bytes",
"status",
"recommended",
"settings"
"settings",
"vision_capable"
],
"properties": {
"filename": {
@ -5733,6 +5734,14 @@
"id": {
"type": "string"
},
"mmproj_status": {
"allOf": [
{
"$ref": "#/components/schemas/ModelDownloadStatus"
}
],
"nullable": true
},
"quantization": {
"type": "string"
},
@ -5752,6 +5761,9 @@
},
"status": {
"$ref": "#/components/schemas/ModelDownloadStatus"
},
"vision_capable": {
"type": "boolean"
}
}
},
@ -6497,11 +6509,22 @@
"type": "number",
"format": "float"
},
"image_token_estimate": {
"type": "integer",
"description": "Estimated tokens per image for budget planning before mtmd tokenization.\nThe actual count is determined after tokenization via `chunks.total_tokens()`.",
"minimum": 0
},
"max_output_tokens": {
"type": "integer",
"nullable": true,
"minimum": 0
},
"mmproj_size_bytes": {
"type": "integer",
"format": "int64",
"description": "Size of the mmproj file in bytes, used for memory accounting.",
"minimum": 0
},
"n_batch": {
"type": "integer",
"format": "int32",
@ -6542,6 +6565,10 @@
},
"use_mlock": {
"type": "boolean"
},
"vision_capable": {
"type": "boolean",
"description": "Whether this model architecture supports vision input.\nDerived from the featured model table, not user-configurable."
}
}
},

View file

@ -612,12 +612,14 @@ export type LoadedProvider = {
export type LocalModelResponse = {
filename: string;
id: string;
mmproj_status?: ModelDownloadStatus | null;
quantization: string;
recommended: boolean;
repo_id: string;
settings: ModelSettings;
size_bytes: number;
status: ModelDownloadStatus;
vision_capable: boolean;
};
/**
@ -821,7 +823,16 @@ export type ModelSettings = {
enable_thinking?: boolean;
flash_attention?: boolean | null;
frequency_penalty?: number;
/**
* Estimated tokens per image for budget planning before mtmd tokenization.
* The actual count is determined after tokenization via `chunks.total_tokens()`.
*/
image_token_estimate?: number;
max_output_tokens?: number | null;
/**
* Size of the mmproj file in bytes, used for memory accounting.
*/
mmproj_size_bytes?: number;
n_batch?: number | null;
n_gpu_layers?: number | null;
n_threads?: number | null;
@ -832,6 +843,11 @@ export type ModelSettings = {
sampling?: SamplingConfig;
use_jinja?: boolean;
use_mlock?: boolean;
/**
* Whether this model architecture supports vision input.
* Derived from the featured model table, not user-configurable.
*/
vision_capable?: boolean;
};
export type ModelTemplate = {

View file

@ -1,5 +1,5 @@
import { useState, useEffect, useCallback, useRef } from 'react';
import { Download, Trash2, X, ChevronDown, ChevronUp, Settings2 } from 'lucide-react';
import { Download, Trash2, X, ChevronDown, ChevronUp, Settings2, Eye } from 'lucide-react';
import { Button } from '../../ui/button';
import { useModelAndProvider } from '../../ModelAndProviderContext';
import { defineMessages, useIntl } from '../../../i18n';
@ -83,8 +83,57 @@ const i18n = defineMessages({
id: 'localInferenceSettings.modelSettingsTitle',
defaultMessage: 'Model settings',
},
vision: {
id: 'localInferenceSettings.vision',
defaultMessage: 'Vision',
},
visionEncoderDownloading: {
id: 'localInferenceSettings.visionEncoderDownloading',
defaultMessage: 'Vision encoder downloading…',
},
visionEncoderNotDownloaded: {
id: 'localInferenceSettings.visionEncoderNotDownloaded',
defaultMessage: 'Vision encoder not downloaded',
},
});
const VisionBadge = ({ model, intl }: { model: LocalModelResponse; intl: ReturnType<typeof useIntl> }) => {
if (!model.vision_capable) return null;
const mmproj = model.mmproj_status;
const isDownloaded = mmproj?.state === 'Downloaded';
const isDownloading = mmproj?.state === 'Downloading';
if (isDownloaded) {
return (
<span className="inline-flex items-center gap-1 text-xs text-green-400 bg-green-500/10 px-2 py-0.5 rounded">
<Eye className="w-3 h-3" />
{intl.formatMessage(i18n.vision)}
</span>
);
}
if (isDownloading) {
const percent = mmproj && 'progress_percent' in mmproj
? Math.round(mmproj.progress_percent)
: null;
return (
<span className="inline-flex items-center gap-1 text-xs text-yellow-400 bg-yellow-500/10 px-2 py-0.5 rounded">
<Eye className="w-3 h-3" />
{intl.formatMessage(i18n.visionEncoderDownloading)}
{percent != null && ` ${percent}%`}
</span>
);
}
return (
<span className="inline-flex items-center gap-1 text-xs text-text-muted bg-background-subtle px-2 py-0.5 rounded">
<Eye className="w-3 h-3" />
{intl.formatMessage(i18n.vision)}
</span>
);
};
const formatBytes = (bytes: number): string => {
if (bytes < 1024) return `${bytes}B`;
if (bytes < 1024 * 1024) return `${(bytes / 1024).toFixed(0)}KB`;
@ -128,6 +177,19 @@ export const LocalInferenceSettings = () => {
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
// Poll model list while any vision encoder is downloading
useEffect(() => {
const hasDownloadingMmproj = models.some(
(m) => m.vision_capable && m.mmproj_status?.state === 'Downloading'
);
if (!hasDownloadingMmproj) return;
const interval = setInterval(() => {
loadModels();
}, 2000);
return () => clearInterval(interval);
}, [models, loadModels]);
const selectModel = async (modelId: string) => {
try {
await setConfigProvider({
@ -365,6 +427,7 @@ export const LocalInferenceSettings = () => {
{intl.formatMessage(i18n.recommended)}
</span>
)}
<VisionBadge model={model} intl={intl} />
</div>
<div className="flex items-center gap-1">
<Button
@ -414,6 +477,7 @@ export const LocalInferenceSettings = () => {
{intl.formatMessage(i18n.recommended)}
</span>
)}
<VisionBadge model={model} intl={intl} />
</div>
</div>
<Button

View file

@ -1838,6 +1838,15 @@
"localInferenceSettings.title": {
"defaultMessage": "Local Inference Models"
},
"localInferenceSettings.vision": {
"defaultMessage": "Vision"
},
"localInferenceSettings.visionEncoderDownloading": {
"defaultMessage": "Vision encoder downloading\u2026"
},
"localInferenceSettings.visionEncoderNotDownloaded": {
"defaultMessage": "Vision encoder not downloaded"
},
"localModelManager.active": {
"defaultMessage": "Active"
},