mirror of
https://github.com/block/goose.git
synced 2026-04-26 10:40:45 +00:00
Add vision/image support for local inference models (#8442)
Some checks failed
Canary / Prepare Version (push) Waiting to run
Canary / build-cli (push) Blocked by required conditions
Canary / Upload Install Script (push) Blocked by required conditions
Canary / bundle-desktop (push) Blocked by required conditions
Canary / bundle-desktop-intel (push) Blocked by required conditions
Canary / bundle-desktop-linux (push) Blocked by required conditions
Canary / bundle-desktop-windows (push) Blocked by required conditions
Canary / Release (push) Blocked by required conditions
Unused Dependencies / machete (push) Waiting to run
CI / changes (push) Waiting to run
CI / Check Rust Code Format (push) Blocked by required conditions
CI / Build and Test Rust Project (push) Blocked by required conditions
CI / Build Rust Project on Windows (push) Waiting to run
CI / Lint Rust Code (push) Blocked by required conditions
CI / Check Generated Schemas are Up-to-Date (push) Blocked by required conditions
CI / Test and Lint Electron Desktop App (push) Blocked by required conditions
Live Provider Tests / check-fork (push) Waiting to run
Live Provider Tests / changes (push) Blocked by required conditions
Live Provider Tests / Build Binary (push) Blocked by required conditions
Live Provider Tests / Smoke Tests (push) Blocked by required conditions
Live Provider Tests / Smoke Tests (Code Execution) (push) Blocked by required conditions
Live Provider Tests / Compaction Tests (push) Blocked by required conditions
Live Provider Tests / goose server HTTP integration tests (push) Blocked by required conditions
Publish Docker Image / docker (push) Waiting to run
Scorecard supply-chain security / Scorecard analysis (push) Waiting to run
Cargo Deny / deny (push) Has been cancelled
Some checks failed
Canary / Prepare Version (push) Waiting to run
Canary / build-cli (push) Blocked by required conditions
Canary / Upload Install Script (push) Blocked by required conditions
Canary / bundle-desktop (push) Blocked by required conditions
Canary / bundle-desktop-intel (push) Blocked by required conditions
Canary / bundle-desktop-linux (push) Blocked by required conditions
Canary / bundle-desktop-windows (push) Blocked by required conditions
Canary / Release (push) Blocked by required conditions
Unused Dependencies / machete (push) Waiting to run
CI / changes (push) Waiting to run
CI / Check Rust Code Format (push) Blocked by required conditions
CI / Build and Test Rust Project (push) Blocked by required conditions
CI / Build Rust Project on Windows (push) Waiting to run
CI / Lint Rust Code (push) Blocked by required conditions
CI / Check Generated Schemas are Up-to-Date (push) Blocked by required conditions
CI / Test and Lint Electron Desktop App (push) Blocked by required conditions
Live Provider Tests / check-fork (push) Waiting to run
Live Provider Tests / changes (push) Blocked by required conditions
Live Provider Tests / Build Binary (push) Blocked by required conditions
Live Provider Tests / Smoke Tests (push) Blocked by required conditions
Live Provider Tests / Smoke Tests (Code Execution) (push) Blocked by required conditions
Live Provider Tests / Compaction Tests (push) Blocked by required conditions
Live Provider Tests / goose server HTTP integration tests (push) Blocked by required conditions
Publish Docker Image / docker (push) Waiting to run
Scorecard supply-chain security / Scorecard analysis (push) Waiting to run
Cargo Deny / deny (push) Has been cancelled
Signed-off-by: jh-block <jhugo@block.xyz>
This commit is contained in:
parent
5fa2a8b821
commit
de317d5445
15 changed files with 1181 additions and 88 deletions
|
|
@ -1604,6 +1604,9 @@ async fn handle_local_models_command(command: LocalModelsCommand) -> Result<()>
|
|||
source_url: file.download_url.clone(),
|
||||
settings: Default::default(),
|
||||
size_bytes: file.size_bytes,
|
||||
mmproj_path: None,
|
||||
mmproj_source_url: None,
|
||||
mmproj_size_bytes: 0,
|
||||
};
|
||||
|
||||
{
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
use std::path::PathBuf;
|
||||
|
||||
use crate::routes::errors::ErrorResponse;
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
|
|
@ -13,9 +15,9 @@ use goose::providers::local_inference::{
|
|||
available_inference_memory_bytes,
|
||||
hf_models::{resolve_model_spec, HfGgufFile},
|
||||
local_model_registry::{
|
||||
default_settings_for_model, get_registry, is_featured_model, model_id_from_repo,
|
||||
LocalModelEntry, ModelDownloadStatus as RegistryDownloadStatus, ModelSettings,
|
||||
FEATURED_MODELS,
|
||||
default_settings_for_model, featured_mmproj_spec, get_registry, is_featured_model,
|
||||
model_id_from_repo, LocalModelEntry, ModelDownloadStatus as RegistryDownloadStatus,
|
||||
ModelSettings, FEATURED_MODELS,
|
||||
},
|
||||
recommend_local_model,
|
||||
};
|
||||
|
|
@ -47,10 +49,14 @@ pub struct LocalModelResponse {
|
|||
pub status: ModelDownloadStatus,
|
||||
pub recommended: bool,
|
||||
pub settings: ModelSettings,
|
||||
pub vision_capable: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub mmproj_status: Option<ModelDownloadStatus>,
|
||||
}
|
||||
|
||||
async fn ensure_featured_models_in_registry() -> Result<(), ErrorResponse> {
|
||||
let mut entries_to_add = Vec::new();
|
||||
let mut mmproj_downloads_needed: Vec<(String, String, PathBuf)> = Vec::new();
|
||||
|
||||
for featured in FEATURED_MODELS {
|
||||
let (repo_id, quantization) = match hf_models::parse_model_spec(featured.spec) {
|
||||
|
|
@ -64,8 +70,27 @@ async fn ensure_featured_models_in_registry() -> Result<(), ErrorResponse> {
|
|||
let registry = get_registry()
|
||||
.lock()
|
||||
.map_err(|_| ErrorResponse::internal("Failed to acquire registry lock"))?;
|
||||
if registry.has_model(&model_id) {
|
||||
continue;
|
||||
if let Some(existing) = registry.get_model(&model_id) {
|
||||
let needs_backfill = existing.mmproj_path.is_none() && featured.mmproj.is_some();
|
||||
let needs_download = existing.is_downloaded()
|
||||
&& featured.mmproj.is_some()
|
||||
&& !existing.mmproj_path.as_ref().is_some_and(|p| p.exists());
|
||||
|
||||
if needs_download {
|
||||
if let Some(mmproj) = featured.mmproj.as_ref() {
|
||||
let path = mmproj.local_path();
|
||||
let url = format!(
|
||||
"https://huggingface.co/{}/resolve/main/{}",
|
||||
mmproj.repo, mmproj.filename
|
||||
);
|
||||
mmproj_downloads_needed.push((model_id.clone(), url, path));
|
||||
}
|
||||
}
|
||||
|
||||
if !needs_backfill {
|
||||
continue;
|
||||
}
|
||||
// Fall through to build the entry for sync_with_featured backfill
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -91,6 +116,8 @@ async fn ensure_featured_models_in_registry() -> Result<(), ErrorResponse> {
|
|||
|
||||
let local_path = Paths::in_data_dir("models").join(&hf_file.filename);
|
||||
|
||||
// enrich_with_featured_mmproj is called by sync_with_featured/add_model,
|
||||
// so we don't need to populate mmproj fields here.
|
||||
entries_to_add.push(LocalModelEntry {
|
||||
id: model_id.clone(),
|
||||
repo_id,
|
||||
|
|
@ -100,14 +127,58 @@ async fn ensure_featured_models_in_registry() -> Result<(), ErrorResponse> {
|
|||
source_url: hf_file.download_url,
|
||||
settings: default_settings_for_model(&model_id),
|
||||
size_bytes: hf_file.size_bytes,
|
||||
mmproj_path: None,
|
||||
mmproj_source_url: None,
|
||||
mmproj_size_bytes: 0,
|
||||
});
|
||||
}
|
||||
|
||||
if !entries_to_add.is_empty() {
|
||||
{
|
||||
let mut registry = get_registry()
|
||||
.lock()
|
||||
.map_err(|_| ErrorResponse::internal("Failed to acquire registry lock"))?;
|
||||
registry.sync_with_featured(entries_to_add);
|
||||
|
||||
if !entries_to_add.is_empty() {
|
||||
registry.sync_with_featured(entries_to_add);
|
||||
}
|
||||
|
||||
// Backfill mmproj data for all registry models and collect any
|
||||
// needed mmproj downloads for models already on disk.
|
||||
for model in registry.list_models_mut() {
|
||||
model.enrich_with_featured_mmproj();
|
||||
if model.is_downloaded() {
|
||||
if let Some(mmproj) = featured_mmproj_spec(&model.id) {
|
||||
let path = mmproj.local_path();
|
||||
if !path.exists() {
|
||||
let url = format!(
|
||||
"https://huggingface.co/{}/resolve/main/{}",
|
||||
mmproj.repo, mmproj.filename
|
||||
);
|
||||
mmproj_downloads_needed.push((model.id.clone(), url, path));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
let _ = registry.save();
|
||||
}
|
||||
|
||||
// Auto-download mmproj files for models that are already downloaded.
|
||||
// Deduplicate by path since multiple quants share one mmproj file.
|
||||
let dm = get_download_manager();
|
||||
let mut started_paths = std::collections::HashSet::new();
|
||||
for (model_id, url, path) in mmproj_downloads_needed {
|
||||
if !path.exists() && started_paths.insert(path.clone()) {
|
||||
let download_id = format!("{}-mmproj", model_id);
|
||||
let dominated_by_active = dm
|
||||
.get_progress(&download_id)
|
||||
.is_some_and(|p| p.status == goose::download_manager::DownloadStatus::Downloading);
|
||||
if !dominated_by_active {
|
||||
tracing::info!(model_id = %model_id, "Auto-downloading vision encoder for existing model");
|
||||
if let Err(e) = dm.download_model(download_id, url, path, None).await {
|
||||
tracing::warn!(model_id = %model_id, error = %e, "Failed to start mmproj download");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
|
@ -154,6 +225,28 @@ pub async fn list_local_models(
|
|||
|
||||
let size_bytes = entry.file_size();
|
||||
|
||||
let vision_capable = entry.settings.vision_capable;
|
||||
let mmproj_status = if vision_capable {
|
||||
let ms = entry.mmproj_download_status();
|
||||
Some(match ms {
|
||||
RegistryDownloadStatus::NotDownloaded => ModelDownloadStatus::NotDownloaded,
|
||||
RegistryDownloadStatus::Downloading {
|
||||
progress_percent,
|
||||
bytes_downloaded,
|
||||
total_bytes,
|
||||
speed_bps,
|
||||
} => ModelDownloadStatus::Downloading {
|
||||
progress_percent,
|
||||
bytes_downloaded,
|
||||
total_bytes,
|
||||
speed_bps: Some(speed_bps),
|
||||
},
|
||||
RegistryDownloadStatus::Downloaded => ModelDownloadStatus::Downloaded,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
models.push(LocalModelResponse {
|
||||
id: entry.id.clone(),
|
||||
repo_id: entry.repo_id.clone(),
|
||||
|
|
@ -163,6 +256,8 @@ pub async fn list_local_models(
|
|||
status,
|
||||
recommended: recommended_id == entry.id,
|
||||
settings: entry.settings.clone(),
|
||||
vision_capable,
|
||||
mmproj_status,
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -276,16 +371,26 @@ pub async fn download_hf_model(
|
|||
source_url: download_url.clone(),
|
||||
settings: default_settings_for_model(&model_id),
|
||||
size_bytes: hf_file.size_bytes,
|
||||
mmproj_path: None,
|
||||
mmproj_source_url: None,
|
||||
mmproj_size_bytes: 0,
|
||||
};
|
||||
|
||||
{
|
||||
// add_model enriches the entry with mmproj metadata from the featured table
|
||||
let mmproj_path = {
|
||||
let mut registry = get_registry()
|
||||
.lock()
|
||||
.map_err(|_| ErrorResponse::internal("Failed to acquire registry lock"))?;
|
||||
registry
|
||||
.add_model(entry)
|
||||
.map_err(|e| ErrorResponse::internal(format!("{}", e)))?;
|
||||
}
|
||||
registry.get_model(&model_id).and_then(|e| {
|
||||
e.mmproj_path
|
||||
.as_ref()
|
||||
.zip(e.mmproj_source_url.as_ref())
|
||||
.map(|(p, u)| (p.clone(), u.clone()))
|
||||
})
|
||||
};
|
||||
|
||||
let dm = get_download_manager();
|
||||
dm.download_model(
|
||||
|
|
@ -297,6 +402,19 @@ pub async fn download_hf_model(
|
|||
.await
|
||||
.map_err(|e| ErrorResponse::internal(format!("Download failed: {}", e)))?;
|
||||
|
||||
if let Some((mmproj_path, mmproj_url)) = mmproj_path {
|
||||
if !mmproj_path.exists() {
|
||||
dm.download_model(
|
||||
format!("{}-mmproj", model_id),
|
||||
mmproj_url,
|
||||
mmproj_path,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| ErrorResponse::internal(format!("mmproj download failed: {}", e)))?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok((StatusCode::ACCEPTED, Json(model_id)))
|
||||
}
|
||||
|
||||
|
|
@ -338,6 +456,7 @@ pub async fn cancel_local_model_download(
|
|||
manager
|
||||
.cancel_download(&format!("{}-model", model_id))
|
||||
.map_err(|e| ErrorResponse::internal(format!("{}", e)))?;
|
||||
let _ = manager.cancel_download(&format!("{}-mmproj", model_id));
|
||||
|
||||
Ok(StatusCode::OK)
|
||||
}
|
||||
|
|
@ -351,14 +470,22 @@ pub async fn cancel_local_model_download(
|
|||
)
|
||||
)]
|
||||
pub async fn delete_local_model(Path(model_id): Path<String>) -> Result<StatusCode, ErrorResponse> {
|
||||
let local_path = {
|
||||
let (local_path, mmproj_path, other_uses_mmproj) = {
|
||||
let registry = get_registry()
|
||||
.lock()
|
||||
.map_err(|_| ErrorResponse::internal("Failed to acquire registry lock"))?;
|
||||
let entry = registry
|
||||
.get_model(&model_id)
|
||||
.ok_or_else(|| ErrorResponse::not_found("Model not found"))?;
|
||||
entry.local_path.clone()
|
||||
let lp = entry.local_path.clone();
|
||||
let mp = entry.mmproj_path.clone();
|
||||
// Check if another downloaded model shares this mmproj file
|
||||
let shared = mp.as_ref().is_some_and(|target| {
|
||||
registry.list_models().iter().any(|m| {
|
||||
m.id != model_id && m.is_downloaded() && m.mmproj_path.as_ref() == Some(target)
|
||||
})
|
||||
});
|
||||
(lp, mp, shared)
|
||||
};
|
||||
|
||||
if local_path.exists() {
|
||||
|
|
@ -367,6 +494,14 @@ pub async fn delete_local_model(Path(model_id): Path<String>) -> Result<StatusCo
|
|||
.map_err(|e| ErrorResponse::internal(format!("Failed to delete: {}", e)))?;
|
||||
}
|
||||
|
||||
if !other_uses_mmproj {
|
||||
if let Some(mmproj) = mmproj_path {
|
||||
if mmproj.exists() {
|
||||
let _ = tokio::fs::remove_file(&mmproj).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Only remove non-featured models from registry (featured ones stay as placeholders)
|
||||
if !is_featured_model(&model_id) {
|
||||
let mut registry = get_registry()
|
||||
|
|
|
|||
|
|
@ -179,7 +179,7 @@ tree-sitter-typescript = { workspace = true }
|
|||
which = { workspace = true }
|
||||
pctx_code_mode = { version = "^0.3.0", optional = true }
|
||||
pulldown-cmark = "0.13.0"
|
||||
llama-cpp-2 = { version = "0.1.143", features = ["sampler"], optional = true }
|
||||
llama-cpp-2 = { version = "0.1.143", features = ["sampler", "mtmd"], optional = true }
|
||||
encoding_rs = "0.8.35"
|
||||
pastey = "0.2.1"
|
||||
shell-words = { workspace = true }
|
||||
|
|
@ -197,7 +197,7 @@ keyring = { version = "3.6.2", features = ["windows-native"] }
|
|||
[target.'cfg(target_os = "macos")'.dependencies]
|
||||
candle-core = { version = "0.9", default-features = false, features = ["metal"], optional = true }
|
||||
candle-nn = { version = "0.9", default-features = false, features = ["metal"], optional = true }
|
||||
llama-cpp-2 = { version = "0.1.143", features = ["sampler", "metal"], optional = true }
|
||||
llama-cpp-2 = { version = "0.1.143", features = ["sampler", "metal", "mtmd"], optional = true }
|
||||
keyring = { version = "3.6.2", features = ["apple-native"] }
|
||||
|
||||
[target.'cfg(target_os = "linux")'.dependencies]
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ mod inference_emulated_tools;
|
|||
mod inference_engine;
|
||||
mod inference_native_tools;
|
||||
pub mod local_model_registry;
|
||||
pub(crate) mod multimodal;
|
||||
mod tool_parsing;
|
||||
|
||||
use inference_emulated_tools::{
|
||||
|
|
@ -30,6 +31,7 @@ use llama_cpp_2::llama_backend::LlamaBackend;
|
|||
use llama_cpp_2::model::params::LlamaModelParams;
|
||||
use llama_cpp_2::model::{LlamaChatMessage, LlamaChatTemplate, LlamaModel};
|
||||
use llama_cpp_2::{list_llama_ggml_backend_devices, LlamaBackendDeviceType, LogOptions};
|
||||
use multimodal::ExtractedImage;
|
||||
use rmcp::model::{Role, Tool};
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
|
|
@ -114,14 +116,15 @@ const DEFAULT_MODEL: &str = "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M";
|
|||
|
||||
pub const LOCAL_LLM_MODEL_CONFIG_KEY: &str = "LOCAL_LLM_MODEL";
|
||||
|
||||
/// Resolve model path, context limit, and settings for a model ID from the registry.
|
||||
pub fn resolve_model_path(
|
||||
model_id: &str,
|
||||
) -> Option<(
|
||||
PathBuf,
|
||||
usize,
|
||||
crate::providers::local_inference::local_model_registry::ModelSettings,
|
||||
)> {
|
||||
pub struct ResolvedModelPaths {
|
||||
pub model_path: PathBuf,
|
||||
pub context_limit: usize,
|
||||
pub settings: crate::providers::local_inference::local_model_registry::ModelSettings,
|
||||
pub mmproj_path: Option<PathBuf>,
|
||||
}
|
||||
|
||||
/// Resolve model path, context limit, settings, and mmproj path for a model ID from the registry.
|
||||
pub fn resolve_model_path(model_id: &str) -> Option<ResolvedModelPaths> {
|
||||
use crate::providers::local_inference::local_model_registry::{
|
||||
default_settings_for_model, get_registry,
|
||||
};
|
||||
|
|
@ -135,7 +138,15 @@ pub fn resolve_model_path(
|
|||
// recognized (or with a different quantization) still get the right behavior.
|
||||
let defaults = default_settings_for_model(model_id);
|
||||
settings.native_tool_calling = defaults.native_tool_calling;
|
||||
return Some((entry.local_path.clone(), ctx, settings));
|
||||
settings.vision_capable = defaults.vision_capable;
|
||||
settings.mmproj_size_bytes = entry.mmproj_size_bytes;
|
||||
let mmproj_path = entry.mmproj_path.as_ref().filter(|p| p.exists()).cloned();
|
||||
return Some(ResolvedModelPaths {
|
||||
model_path: entry.local_path.clone(),
|
||||
context_limit: ctx,
|
||||
settings,
|
||||
mmproj_path,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -208,9 +219,33 @@ fn build_openai_messages_json(system: &str, messages: &[Message]) -> String {
|
|||
|
||||
let mut arr: Vec<Value> = vec![json!({"role": "system", "content": system})];
|
||||
arr.extend(format_messages(messages, &ImageFormat::OpenAi));
|
||||
strip_image_parts_from_messages(&mut arr);
|
||||
serde_json::to_string(&arr).unwrap_or_else(|_| "[]".to_string())
|
||||
}
|
||||
|
||||
/// Remove `image_url` content parts from OpenAI-format messages JSON, replacing
|
||||
/// each with a text note. This prevents an FFI crash in llama.cpp which does not
|
||||
/// accept `image_url` content-part types.
|
||||
fn strip_image_parts_from_messages(messages: &mut [Value]) {
|
||||
let mut stripped = false;
|
||||
for msg in messages.iter_mut() {
|
||||
if let Some(content) = msg.get_mut("content").and_then(|c| c.as_array_mut()) {
|
||||
for part in content.iter_mut() {
|
||||
if part.get("type").and_then(|t| t.as_str()) == Some("image_url") {
|
||||
*part = json!({
|
||||
"type": "text",
|
||||
"text": "[Image attached — image input is not supported with the currently selected model]"
|
||||
});
|
||||
stripped = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if stripped {
|
||||
tracing::warn!("Stripped image content parts from messages — vision encoder not available for this model");
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a message into plain text for the emulator path's chat history.
|
||||
///
|
||||
/// This is the emulator-path counterpart of [`format_messages`] used by the native
|
||||
|
|
@ -269,6 +304,12 @@ fn extract_text_content(msg: &Message) -> String {
|
|||
parts.push(format!("Command error: {}", e));
|
||||
}
|
||||
},
|
||||
MessageContent::Image(_) => {
|
||||
parts.push(
|
||||
"[Image attached — image input is not supported with the currently selected model]"
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
|
@ -331,8 +372,9 @@ impl LocalInferenceProvider {
|
|||
model_id: &str,
|
||||
settings: &crate::providers::local_inference::local_model_registry::ModelSettings,
|
||||
) -> Result<LoadedModel, ProviderError> {
|
||||
let (model_path, _context_limit, _) = resolve_model_path(model_id)
|
||||
let resolved = resolve_model_path(model_id)
|
||||
.ok_or_else(|| ProviderError::ExecutionError(format!("Unknown model: {}", model_id)))?;
|
||||
let model_path = resolved.model_path;
|
||||
|
||||
if !model_path.exists() {
|
||||
return Err(ProviderError::ExecutionError(format!(
|
||||
|
|
@ -368,9 +410,49 @@ impl LocalInferenceProvider {
|
|||
}
|
||||
};
|
||||
|
||||
let mtmd_ctx = Self::init_mtmd_context(&model, &resolved.mmproj_path, settings);
|
||||
|
||||
tracing::info!(model_id = model_id, "Model loaded successfully");
|
||||
|
||||
Ok(LoadedModel { model, template })
|
||||
Ok(LoadedModel {
|
||||
model,
|
||||
template,
|
||||
mtmd_ctx,
|
||||
})
|
||||
}
|
||||
|
||||
fn init_mtmd_context(
|
||||
model: &LlamaModel,
|
||||
mmproj_path: &Option<PathBuf>,
|
||||
settings: &crate::providers::local_inference::local_model_registry::ModelSettings,
|
||||
) -> Option<llama_cpp_2::mtmd::MtmdContext> {
|
||||
use llama_cpp_2::mtmd::{MtmdContext, MtmdContextParams};
|
||||
|
||||
let mmproj_path = mmproj_path.as_ref().filter(|p| p.exists())?;
|
||||
|
||||
let params = MtmdContextParams {
|
||||
use_gpu: true,
|
||||
n_threads: settings
|
||||
.n_threads
|
||||
.unwrap_or_else(|| MtmdContextParams::default().n_threads),
|
||||
..MtmdContextParams::default()
|
||||
};
|
||||
|
||||
match MtmdContext::init_from_file(mmproj_path.to_str().unwrap_or_default(), model, ¶ms)
|
||||
{
|
||||
Ok(ctx) => {
|
||||
tracing::info!(
|
||||
vision = ctx.support_vision(),
|
||||
audio = ctx.support_audio(),
|
||||
"Multimodal context initialized"
|
||||
);
|
||||
Some(ctx)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Failed to init multimodal context");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -453,13 +535,11 @@ impl Provider for LocalInferenceProvider {
|
|||
messages: &[Message],
|
||||
tools: &[Tool],
|
||||
) -> Result<MessageStream, ProviderError> {
|
||||
let (_model_path, model_context_limit, model_settings) =
|
||||
resolve_model_path(&model_config.model_name).ok_or_else(|| {
|
||||
ProviderError::ExecutionError(format!(
|
||||
"Model not found: {}",
|
||||
model_config.model_name
|
||||
))
|
||||
})?;
|
||||
let resolved = resolve_model_path(&model_config.model_name).ok_or_else(|| {
|
||||
ProviderError::ExecutionError(format!("Model not found: {}", model_config.model_name))
|
||||
})?;
|
||||
let model_context_limit = resolved.context_limit;
|
||||
let model_settings = resolved.settings;
|
||||
|
||||
// Ensure model is loaded — unload any other models first to free memory.
|
||||
{
|
||||
|
|
@ -503,6 +583,18 @@ impl Provider for LocalInferenceProvider {
|
|||
system.to_string()
|
||||
};
|
||||
|
||||
// Extract images for vision-capable models, replacing them with markers.
|
||||
// For non-vision models, leave messages unchanged (existing strip logic handles them).
|
||||
let has_vision = resolved.mmproj_path.is_some();
|
||||
let marker = llama_cpp_2::mtmd::mtmd_default_marker();
|
||||
let (images, vision_messages): (Vec<ExtractedImage>, Option<Vec<Message>>) = if has_vision {
|
||||
let (imgs, msgs) = multimodal::extract_images_from_messages(messages, marker);
|
||||
(imgs, Some(msgs))
|
||||
} else {
|
||||
(Vec::new(), None)
|
||||
};
|
||||
let effective_messages: &[Message] = vision_messages.as_deref().unwrap_or(messages);
|
||||
|
||||
// Build chat messages for the template
|
||||
let mut chat_messages =
|
||||
vec![
|
||||
|
|
@ -529,7 +621,7 @@ impl Provider for LocalInferenceProvider {
|
|||
})?];
|
||||
}
|
||||
|
||||
for msg in messages {
|
||||
for msg in effective_messages {
|
||||
let role = match msg.role {
|
||||
Role::User => "user",
|
||||
Role::Assistant => "assistant",
|
||||
|
|
@ -553,7 +645,10 @@ impl Provider for LocalInferenceProvider {
|
|||
};
|
||||
|
||||
let oai_messages_json = if model_settings.use_jinja || native_tool_calling {
|
||||
Some(build_openai_messages_json(&system_prompt, messages))
|
||||
Some(build_openai_messages_json(
|
||||
&system_prompt,
|
||||
effective_messages,
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
|
@ -563,6 +658,7 @@ impl Provider for LocalInferenceProvider {
|
|||
let model_name = model_config.model_name.clone();
|
||||
let context_limit = model_context_limit;
|
||||
let settings = model_settings;
|
||||
let mmproj_path = resolved.mmproj_path.clone();
|
||||
|
||||
let log_payload = serde_json::json!({
|
||||
"system": &system_prompt,
|
||||
|
|
@ -604,8 +700,8 @@ impl Provider for LocalInferenceProvider {
|
|||
}};
|
||||
}
|
||||
|
||||
let model_guard = model_arc.blocking_lock();
|
||||
let loaded = match model_guard.as_ref() {
|
||||
let mut model_guard = model_arc.blocking_lock();
|
||||
let loaded = match model_guard.as_mut() {
|
||||
Some(l) => l,
|
||||
None => {
|
||||
send_err!(ProviderError::ExecutionError(
|
||||
|
|
@ -614,6 +710,16 @@ impl Provider for LocalInferenceProvider {
|
|||
}
|
||||
};
|
||||
|
||||
// Lazily initialize the multimodal context if the vision encoder
|
||||
// was downloaded after the model was loaded.
|
||||
if !images.is_empty() && loaded.mtmd_ctx.is_none() {
|
||||
loaded.mtmd_ctx = LocalInferenceProvider::init_mtmd_context(
|
||||
&loaded.model,
|
||||
&mmproj_path,
|
||||
&settings,
|
||||
);
|
||||
}
|
||||
|
||||
let message_id = Uuid::new_v4().to_string();
|
||||
|
||||
let mut gen_ctx = GenerationContext {
|
||||
|
|
@ -626,6 +732,7 @@ impl Provider for LocalInferenceProvider {
|
|||
message_id: &message_id,
|
||||
tx: &tx,
|
||||
log: &mut log,
|
||||
images: &images,
|
||||
};
|
||||
|
||||
let result = if use_emulator {
|
||||
|
|
|
|||
|
|
@ -29,8 +29,8 @@ use std::borrow::Cow;
|
|||
use uuid::Uuid;
|
||||
|
||||
use super::inference_engine::{
|
||||
create_and_prefill_context, generation_loop, validate_and_compute_context, GenerationContext,
|
||||
TokenAction,
|
||||
create_and_prefill_context, create_and_prefill_multimodal, generation_loop,
|
||||
validate_and_compute_context, GenerationContext, TokenAction,
|
||||
};
|
||||
use super::{finalize_usage, StreamSender, CODE_EXECUTION_TOOL, SHELL_TOOL};
|
||||
|
||||
|
|
@ -370,26 +370,32 @@ pub(super) fn generate_with_emulated_tools(
|
|||
ProviderError::ExecutionError(format!("Failed to apply chat template: {}", e))
|
||||
})?;
|
||||
|
||||
let tokens = ctx
|
||||
.loaded
|
||||
.model
|
||||
.str_to_token(&prompt, AddBos::Never)
|
||||
.map_err(|e| ProviderError::ExecutionError(e.to_string()))?;
|
||||
|
||||
let (prompt_token_count, effective_ctx) = validate_and_compute_context(
|
||||
ctx.loaded,
|
||||
ctx.runtime,
|
||||
tokens.len(),
|
||||
ctx.context_limit,
|
||||
ctx.settings,
|
||||
)?;
|
||||
let mut llama_ctx = create_and_prefill_context(
|
||||
ctx.loaded,
|
||||
ctx.runtime,
|
||||
&tokens,
|
||||
effective_ctx,
|
||||
ctx.settings,
|
||||
)?;
|
||||
let (mut llama_ctx, prompt_token_count, effective_ctx) = if !ctx.images.is_empty() {
|
||||
create_and_prefill_multimodal(
|
||||
ctx.loaded,
|
||||
ctx.runtime,
|
||||
&prompt,
|
||||
ctx.images,
|
||||
ctx.context_limit,
|
||||
ctx.settings,
|
||||
)?
|
||||
} else {
|
||||
let tokens = ctx
|
||||
.loaded
|
||||
.model
|
||||
.str_to_token(&prompt, AddBos::Never)
|
||||
.map_err(|e| ProviderError::ExecutionError(e.to_string()))?;
|
||||
let (ptc, ectx) = validate_and_compute_context(
|
||||
ctx.loaded,
|
||||
ctx.runtime,
|
||||
tokens.len(),
|
||||
ctx.context_limit,
|
||||
ctx.settings,
|
||||
)?;
|
||||
let lctx =
|
||||
create_and_prefill_context(ctx.loaded, ctx.runtime, &tokens, ectx, ctx.settings)?;
|
||||
(lctx, ptc, ectx)
|
||||
};
|
||||
|
||||
let message_id = ctx.message_id;
|
||||
let tx = ctx.tx;
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
use crate::providers::errors::ProviderError;
|
||||
use crate::providers::local_inference::local_model_registry::ModelSettings;
|
||||
use crate::providers::local_inference::multimodal::ExtractedImage;
|
||||
use crate::providers::utils::RequestLog;
|
||||
use llama_cpp_2::context::params::LlamaContextParams;
|
||||
use llama_cpp_2::llama_batch::LlamaBatch;
|
||||
use llama_cpp_2::model::{LlamaChatMessage, LlamaChatTemplate, LlamaModel};
|
||||
use llama_cpp_2::mtmd::{MtmdBitmap, MtmdContext, MtmdInputText};
|
||||
use llama_cpp_2::sampling::LlamaSampler;
|
||||
use std::num::NonZeroU32;
|
||||
|
||||
|
|
@ -19,11 +21,14 @@ pub(super) struct GenerationContext<'a> {
|
|||
pub message_id: &'a str,
|
||||
pub tx: &'a StreamSender,
|
||||
pub log: &'a mut RequestLog,
|
||||
pub images: &'a [ExtractedImage],
|
||||
}
|
||||
|
||||
pub(super) struct LoadedModel {
|
||||
pub model: LlamaModel,
|
||||
pub template: LlamaChatTemplate,
|
||||
/// Multimodal context for vision models. None for text-only models.
|
||||
pub mtmd_ctx: Option<MtmdContext>,
|
||||
}
|
||||
|
||||
/// Estimate the maximum context length that can fit in available accelerator/CPU
|
||||
|
|
@ -33,11 +38,13 @@ pub(super) struct LoadedModel {
|
|||
pub(super) fn estimate_max_context_for_memory(
|
||||
model: &LlamaModel,
|
||||
runtime: &InferenceRuntime,
|
||||
mmproj_overhead_bytes: u64,
|
||||
) -> Option<usize> {
|
||||
let available = super::available_inference_memory_bytes(runtime);
|
||||
if available == 0 {
|
||||
let raw_available = super::available_inference_memory_bytes(runtime);
|
||||
if raw_available == 0 {
|
||||
return None;
|
||||
}
|
||||
let available = raw_available.saturating_sub(mmproj_overhead_bytes);
|
||||
|
||||
// Reserve memory for computation scratch buffers (attention, etc.) and other overhead.
|
||||
// The compute buffer can be 40-50% of the KV cache size for large models, so we
|
||||
|
|
@ -209,7 +216,12 @@ pub(super) fn validate_and_compute_context(
|
|||
settings: &crate::providers::local_inference::local_model_registry::ModelSettings,
|
||||
) -> Result<(usize, usize), ProviderError> {
|
||||
let n_ctx_train = loaded.model.n_ctx_train() as usize;
|
||||
let memory_max_ctx = estimate_max_context_for_memory(&loaded.model, runtime);
|
||||
let mmproj_overhead = if loaded.mtmd_ctx.is_some() {
|
||||
settings.mmproj_size_bytes
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let memory_max_ctx = estimate_max_context_for_memory(&loaded.model, runtime, mmproj_overhead);
|
||||
let effective_ctx = effective_context_size(
|
||||
prompt_token_count,
|
||||
settings,
|
||||
|
|
@ -261,6 +273,80 @@ pub(super) fn create_and_prefill_context<'model>(
|
|||
Ok(ctx)
|
||||
}
|
||||
|
||||
/// Tokenize text + images via mtmd and prefill the context.
|
||||
///
|
||||
/// Returns the llama context, the number of prompt tokens consumed,
|
||||
/// and the effective context size.
|
||||
pub(super) fn create_and_prefill_multimodal<'model>(
|
||||
loaded: &'model LoadedModel,
|
||||
runtime: &InferenceRuntime,
|
||||
prompt_text: &str,
|
||||
images: &[ExtractedImage],
|
||||
context_limit: usize,
|
||||
settings: &ModelSettings,
|
||||
) -> Result<(llama_cpp_2::context::LlamaContext<'model>, usize, usize), ProviderError> {
|
||||
let mtmd_ctx = loaded.mtmd_ctx.as_ref().ok_or_else(|| {
|
||||
ProviderError::ExecutionError(
|
||||
"This model does not have vision support. Download the vision encoder from \
|
||||
Settings > Local Inference, or use a text-only message."
|
||||
.to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
let bitmaps: Vec<MtmdBitmap> = images
|
||||
.iter()
|
||||
.map(|img| {
|
||||
MtmdBitmap::from_buffer(mtmd_ctx, &img.bytes)
|
||||
.map_err(|e| ProviderError::ExecutionError(format!("Failed to decode image: {e}")))
|
||||
})
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
let bitmap_refs: Vec<&MtmdBitmap> = bitmaps.iter().collect();
|
||||
|
||||
let input_text = MtmdInputText {
|
||||
text: prompt_text.to_string(),
|
||||
add_special: true,
|
||||
parse_special: true,
|
||||
};
|
||||
let chunks = mtmd_ctx.tokenize(input_text, &bitmap_refs).map_err(|e| {
|
||||
ProviderError::ExecutionError(format!("Multimodal tokenization failed: {e}"))
|
||||
})?;
|
||||
|
||||
let prompt_token_count = chunks.total_tokens();
|
||||
|
||||
let n_ctx_train = loaded.model.n_ctx_train() as usize;
|
||||
let mmproj_overhead = settings.mmproj_size_bytes;
|
||||
let memory_max_ctx = estimate_max_context_for_memory(&loaded.model, runtime, mmproj_overhead);
|
||||
let effective_ctx = effective_context_size(
|
||||
prompt_token_count,
|
||||
settings,
|
||||
context_limit,
|
||||
n_ctx_train,
|
||||
memory_max_ctx,
|
||||
);
|
||||
|
||||
let min_generation_headroom = 512;
|
||||
if prompt_token_count + min_generation_headroom > effective_ctx {
|
||||
return Err(ProviderError::ContextLengthExceeded(format!(
|
||||
"Multimodal prompt ({prompt_token_count} tokens including images) exceeds \
|
||||
context limit ({effective_ctx} tokens)",
|
||||
)));
|
||||
}
|
||||
|
||||
let ctx_params = build_context_params(effective_ctx as u32, settings);
|
||||
let llama_ctx = loaded
|
||||
.model
|
||||
.new_context(runtime.backend(), ctx_params)
|
||||
.map_err(|e| ProviderError::ExecutionError(format!("Failed to create context: {e}")))?;
|
||||
|
||||
let n_batch = llama_ctx.n_batch() as i32;
|
||||
let _n_past = chunks
|
||||
.eval_chunks(mtmd_ctx, &llama_ctx, 0, 0, n_batch, true)
|
||||
.map_err(|e| ProviderError::ExecutionError(format!("Multimodal eval failed: {e}")))?;
|
||||
|
||||
Ok((llama_ctx, prompt_token_count, effective_ctx))
|
||||
}
|
||||
|
||||
/// Action to take after processing a generated token piece.
|
||||
pub(super) enum TokenAction {
|
||||
Continue,
|
||||
|
|
|
|||
|
|
@ -9,8 +9,9 @@ use uuid::Uuid;
|
|||
|
||||
use super::finalize_usage;
|
||||
use super::inference_engine::{
|
||||
context_cap, create_and_prefill_context, estimate_max_context_for_memory, generation_loop,
|
||||
validate_and_compute_context, GenerationContext, TokenAction,
|
||||
context_cap, create_and_prefill_context, create_and_prefill_multimodal,
|
||||
estimate_max_context_for_memory, generation_loop, validate_and_compute_context,
|
||||
GenerationContext, TokenAction,
|
||||
};
|
||||
|
||||
pub(super) fn generate_with_native_tools(
|
||||
|
|
@ -21,7 +22,13 @@ pub(super) fn generate_with_native_tools(
|
|||
) -> Result<(), ProviderError> {
|
||||
let min_generation_headroom = 512;
|
||||
let n_ctx_train = ctx.loaded.model.n_ctx_train() as usize;
|
||||
let memory_max_ctx = estimate_max_context_for_memory(&ctx.loaded.model, ctx.runtime);
|
||||
let mmproj_overhead = if ctx.loaded.mtmd_ctx.is_some() {
|
||||
ctx.settings.mmproj_size_bytes
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let memory_max_ctx =
|
||||
estimate_max_context_for_memory(&ctx.loaded.model, ctx.runtime, mmproj_overhead);
|
||||
let cap = context_cap(ctx.settings, ctx.context_limit, n_ctx_train, memory_max_ctx);
|
||||
let token_budget = cap.saturating_sub(min_generation_headroom);
|
||||
|
||||
|
|
@ -61,6 +68,8 @@ pub(super) fn generate_with_native_tools(
|
|||
}
|
||||
};
|
||||
|
||||
let estimated_image_tokens = ctx.images.len() * ctx.settings.image_token_estimate;
|
||||
|
||||
let template_result = match apply_template(full_tools_json) {
|
||||
Ok(r) => {
|
||||
let token_count = ctx
|
||||
|
|
@ -69,7 +78,7 @@ pub(super) fn generate_with_native_tools(
|
|||
.str_to_token(&r.prompt, AddBos::Never)
|
||||
.map(|t| t.len())
|
||||
.unwrap_or(0);
|
||||
if token_count > token_budget {
|
||||
if token_count + estimated_image_tokens > token_budget {
|
||||
apply_template(compact_tools).unwrap_or(r)
|
||||
} else {
|
||||
r
|
||||
|
|
@ -85,26 +94,32 @@ pub(super) fn generate_with_native_tools(
|
|||
None,
|
||||
);
|
||||
|
||||
let tokens = ctx
|
||||
.loaded
|
||||
.model
|
||||
.str_to_token(&template_result.prompt, AddBos::Never)
|
||||
.map_err(|e| ProviderError::ExecutionError(e.to_string()))?;
|
||||
|
||||
let (prompt_token_count, effective_ctx) = validate_and_compute_context(
|
||||
ctx.loaded,
|
||||
ctx.runtime,
|
||||
tokens.len(),
|
||||
ctx.context_limit,
|
||||
ctx.settings,
|
||||
)?;
|
||||
let mut llama_ctx = create_and_prefill_context(
|
||||
ctx.loaded,
|
||||
ctx.runtime,
|
||||
&tokens,
|
||||
effective_ctx,
|
||||
ctx.settings,
|
||||
)?;
|
||||
let (mut llama_ctx, prompt_token_count, effective_ctx) = if !ctx.images.is_empty() {
|
||||
create_and_prefill_multimodal(
|
||||
ctx.loaded,
|
||||
ctx.runtime,
|
||||
&template_result.prompt,
|
||||
ctx.images,
|
||||
ctx.context_limit,
|
||||
ctx.settings,
|
||||
)?
|
||||
} else {
|
||||
let tokens = ctx
|
||||
.loaded
|
||||
.model
|
||||
.str_to_token(&template_result.prompt, AddBos::Never)
|
||||
.map_err(|e| ProviderError::ExecutionError(e.to_string()))?;
|
||||
let (ptc, ectx) = validate_and_compute_context(
|
||||
ctx.loaded,
|
||||
ctx.runtime,
|
||||
tokens.len(),
|
||||
ctx.context_limit,
|
||||
ctx.settings,
|
||||
)?;
|
||||
let lctx =
|
||||
create_and_prefill_context(ctx.loaded, ctx.runtime, &tokens, ectx, ctx.settings)?;
|
||||
(lctx, ptc, ectx)
|
||||
};
|
||||
|
||||
let message_id = ctx.message_id;
|
||||
let tx = ctx.tx;
|
||||
|
|
|
|||
|
|
@ -62,12 +62,27 @@ pub struct ModelSettings {
|
|||
pub use_jinja: bool,
|
||||
#[serde(default = "default_true")]
|
||||
pub enable_thinking: bool,
|
||||
/// Whether this model architecture supports vision input.
|
||||
/// Derived from the featured model table, not user-configurable.
|
||||
#[serde(default)]
|
||||
pub vision_capable: bool,
|
||||
/// Estimated tokens per image for budget planning before mtmd tokenization.
|
||||
/// The actual count is determined after tokenization via `chunks.total_tokens()`.
|
||||
#[serde(default = "default_image_token_estimate")]
|
||||
pub image_token_estimate: usize,
|
||||
/// Size of the mmproj file in bytes, used for memory accounting.
|
||||
#[serde(default)]
|
||||
pub mmproj_size_bytes: u64,
|
||||
}
|
||||
|
||||
fn default_true() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn default_image_token_estimate() -> usize {
|
||||
256
|
||||
}
|
||||
|
||||
fn default_repeat_penalty() -> f32 {
|
||||
1.0
|
||||
}
|
||||
|
|
@ -94,41 +109,75 @@ impl Default for ModelSettings {
|
|||
native_tool_calling: false,
|
||||
use_jinja: false,
|
||||
enable_thinking: true,
|
||||
vision_capable: false,
|
||||
image_token_estimate: default_image_token_estimate(),
|
||||
mmproj_size_bytes: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// HuggingFace repo + filename for multimodal projection weights (vision encoder).
|
||||
pub struct MmprojSpec {
|
||||
pub repo: &'static str,
|
||||
pub filename: &'static str,
|
||||
}
|
||||
|
||||
impl MmprojSpec {
|
||||
/// Local path for this mmproj, namespaced by repo to avoid collisions
|
||||
/// between different models that use the same filename.
|
||||
pub fn local_path(&self) -> std::path::PathBuf {
|
||||
let repo_name = self.repo.split('/').next_back().unwrap_or(self.repo);
|
||||
Paths::in_data_dir("models")
|
||||
.join(repo_name)
|
||||
.join(self.filename)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FeaturedModel {
|
||||
/// HuggingFace spec in "author/repo-GGUF:quantization" format.
|
||||
pub spec: &'static str,
|
||||
/// Whether this model's GGUF template supports native tool calling via llama.cpp.
|
||||
pub native_tool_calling: bool,
|
||||
/// Multimodal projection weights spec. None for text-only models.
|
||||
pub mmproj: Option<MmprojSpec>,
|
||||
}
|
||||
|
||||
pub const FEATURED_MODELS: &[FeaturedModel] = &[
|
||||
FeaturedModel {
|
||||
spec: "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M",
|
||||
native_tool_calling: false,
|
||||
mmproj: None,
|
||||
},
|
||||
FeaturedModel {
|
||||
spec: "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M",
|
||||
native_tool_calling: false,
|
||||
mmproj: None,
|
||||
},
|
||||
FeaturedModel {
|
||||
spec: "bartowski/Hermes-2-Pro-Mistral-7B-GGUF:Q4_K_M",
|
||||
native_tool_calling: false,
|
||||
mmproj: None,
|
||||
},
|
||||
FeaturedModel {
|
||||
spec: "bartowski/Mistral-Small-24B-Instruct-2501-GGUF:Q4_K_M",
|
||||
native_tool_calling: false,
|
||||
mmproj: None,
|
||||
},
|
||||
FeaturedModel {
|
||||
spec: "unsloth/gemma-4-E4B-it-GGUF:Q4_K_M",
|
||||
native_tool_calling: true,
|
||||
mmproj: Some(MmprojSpec {
|
||||
repo: "unsloth/gemma-4-E4B-it-GGUF",
|
||||
filename: "mmproj-BF16.gguf",
|
||||
}),
|
||||
},
|
||||
FeaturedModel {
|
||||
spec: "unsloth/gemma-4-26B-A4B-it-GGUF:Q4_K_M",
|
||||
native_tool_calling: true,
|
||||
mmproj: Some(MmprojSpec {
|
||||
repo: "unsloth/gemma-4-26B-A4B-it-GGUF",
|
||||
filename: "mmproj-BF16.gguf",
|
||||
}),
|
||||
},
|
||||
];
|
||||
|
||||
|
|
@ -144,10 +193,25 @@ pub fn default_settings_for_model(model_id: &str) -> ModelSettings {
|
|||
});
|
||||
ModelSettings {
|
||||
native_tool_calling: featured.is_some_and(|m| m.native_tool_calling),
|
||||
vision_capable: featured.is_some_and(|m| m.mmproj.is_some()),
|
||||
..ModelSettings::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Look up the `MmprojSpec` for a featured model by its model ID.
|
||||
pub fn featured_mmproj_spec(model_id: &str) -> Option<&'static MmprojSpec> {
|
||||
use super::hf_models::parse_model_spec;
|
||||
let model_repo = model_id.split(':').next().unwrap_or(model_id);
|
||||
FEATURED_MODELS.iter().find_map(|m| {
|
||||
if let Ok((repo_id, _quant)) = parse_model_spec(m.spec) {
|
||||
if repo_id == model_repo {
|
||||
return m.mmproj.as_ref();
|
||||
}
|
||||
}
|
||||
None
|
||||
})
|
||||
}
|
||||
|
||||
/// Check if a model ID corresponds to a featured model.
|
||||
pub fn is_featured_model(model_id: &str) -> bool {
|
||||
use super::hf_models::parse_model_spec;
|
||||
|
|
@ -181,9 +245,42 @@ pub struct LocalModelEntry {
|
|||
pub settings: ModelSettings,
|
||||
#[serde(default)]
|
||||
pub size_bytes: u64,
|
||||
/// Local path to the multimodal projection GGUF (vision encoder).
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub mmproj_path: Option<PathBuf>,
|
||||
/// Download URL for the mmproj file.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub mmproj_source_url: Option<String>,
|
||||
/// Size of the mmproj file in bytes.
|
||||
#[serde(default)]
|
||||
pub mmproj_size_bytes: u64,
|
||||
}
|
||||
|
||||
impl LocalModelEntry {
|
||||
/// Populate mmproj metadata and vision settings from the featured model
|
||||
/// table if this model's repo has a known vision encoder.
|
||||
pub fn enrich_with_featured_mmproj(&mut self) {
|
||||
if let Some(mmproj) = featured_mmproj_spec(&self.id) {
|
||||
let path = mmproj.local_path();
|
||||
if self.mmproj_path.as_ref() != Some(&path) {
|
||||
self.mmproj_path = Some(path.clone());
|
||||
self.mmproj_source_url = Some(format!(
|
||||
"https://huggingface.co/{}/resolve/main/{}",
|
||||
mmproj.repo, mmproj.filename
|
||||
));
|
||||
}
|
||||
self.settings.vision_capable = true;
|
||||
if self.mmproj_size_bytes == 0 || self.settings.mmproj_size_bytes == 0 {
|
||||
if let Ok(meta) = std::fs::metadata(&path) {
|
||||
self.mmproj_size_bytes = meta.len();
|
||||
self.settings.mmproj_size_bytes = meta.len();
|
||||
}
|
||||
}
|
||||
}
|
||||
let defaults = default_settings_for_model(&self.id);
|
||||
self.settings.native_tool_calling = defaults.native_tool_calling;
|
||||
}
|
||||
|
||||
pub fn is_downloaded(&self) -> bool {
|
||||
self.local_path.exists()
|
||||
}
|
||||
|
|
@ -219,6 +316,36 @@ impl LocalModelEntry {
|
|||
ModelDownloadStatus::NotDownloaded
|
||||
}
|
||||
|
||||
pub fn has_vision(&self) -> bool {
|
||||
self.mmproj_path.as_ref().is_some_and(|p| p.exists())
|
||||
}
|
||||
|
||||
pub fn mmproj_download_status(&self) -> ModelDownloadStatus {
|
||||
if let Some(path) = &self.mmproj_path {
|
||||
if path.exists() {
|
||||
return ModelDownloadStatus::Downloaded;
|
||||
}
|
||||
} else {
|
||||
return ModelDownloadStatus::NotDownloaded;
|
||||
}
|
||||
|
||||
let download_id = format!("{}-mmproj", self.id);
|
||||
let manager = get_download_manager();
|
||||
if let Some(progress) = manager.get_progress(&download_id) {
|
||||
return match progress.status {
|
||||
DownloadStatus::Downloading => ModelDownloadStatus::Downloading {
|
||||
progress_percent: progress.progress_percent,
|
||||
bytes_downloaded: progress.bytes_downloaded,
|
||||
total_bytes: progress.total_bytes,
|
||||
speed_bps: progress.speed_bps.unwrap_or(0),
|
||||
},
|
||||
_ => ModelDownloadStatus::NotDownloaded,
|
||||
};
|
||||
}
|
||||
|
||||
ModelDownloadStatus::NotDownloaded
|
||||
}
|
||||
|
||||
pub fn file_size(&self) -> u64 {
|
||||
if self.size_bytes > 0 {
|
||||
return self.size_bytes;
|
||||
|
|
@ -290,8 +417,9 @@ impl LocalModelRegistry {
|
|||
pub fn sync_with_featured(&mut self, featured_entries: Vec<LocalModelEntry>) {
|
||||
let mut changed = false;
|
||||
|
||||
for entry in featured_entries {
|
||||
for mut entry in featured_entries {
|
||||
if !self.models.iter().any(|m| m.id == entry.id) {
|
||||
entry.enrich_with_featured_mmproj();
|
||||
self.models.push(entry);
|
||||
changed = true;
|
||||
}
|
||||
|
|
@ -309,7 +437,8 @@ impl LocalModelRegistry {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn add_model(&mut self, entry: LocalModelEntry) -> Result<()> {
|
||||
pub fn add_model(&mut self, mut entry: LocalModelEntry) -> Result<()> {
|
||||
entry.enrich_with_featured_mmproj();
|
||||
if let Some(existing) = self.models.iter_mut().find(|m| m.id == entry.id) {
|
||||
*existing = entry;
|
||||
} else {
|
||||
|
|
@ -348,6 +477,10 @@ impl LocalModelRegistry {
|
|||
pub fn list_models(&self) -> &[LocalModelEntry] {
|
||||
&self.models
|
||||
}
|
||||
|
||||
pub fn list_models_mut(&mut self) -> &mut [LocalModelEntry] {
|
||||
&mut self.models
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a unique ID for a model from its repo_id and quantization.
|
||||
|
|
|
|||
336
crates/goose/src/providers/local_inference/multimodal.rs
Normal file
336
crates/goose/src/providers/local_inference/multimodal.rs
Normal file
|
|
@ -0,0 +1,336 @@
|
|||
use base64::prelude::*;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::conversation::message::{Message, MessageContent};
|
||||
use crate::providers::errors::ProviderError;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ExtractedImage {
|
||||
pub bytes: Vec<u8>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[allow(dead_code)]
|
||||
pub struct MultimodalMessages {
|
||||
pub messages_json: String,
|
||||
pub images: Vec<ExtractedImage>,
|
||||
}
|
||||
|
||||
/// Walk the OpenAI-format messages JSON array. For each content part with
|
||||
/// `type: "image_url"`, decode the base64 data URL, store the raw bytes,
|
||||
/// and replace the part with `{"type": "text", "text": "<marker>"}`.
|
||||
///
|
||||
/// Returns the modified JSON string and the extracted images in order.
|
||||
#[allow(dead_code)]
|
||||
pub fn extract_images_from_messages_json(
|
||||
messages_json: &str,
|
||||
marker: &str,
|
||||
) -> Result<MultimodalMessages, ProviderError> {
|
||||
let mut messages: Vec<Value> = serde_json::from_str(messages_json).map_err(|e| {
|
||||
ProviderError::ExecutionError(format!("Failed to parse messages JSON: {e}"))
|
||||
})?;
|
||||
|
||||
let mut images = Vec::new();
|
||||
|
||||
for msg in messages.iter_mut() {
|
||||
let Some(content) = msg.get_mut("content").and_then(|c| c.as_array_mut()) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
for part in content.iter_mut() {
|
||||
if part.get("type").and_then(|t| t.as_str()) != Some("image_url") {
|
||||
continue;
|
||||
}
|
||||
|
||||
let url = part
|
||||
.get("image_url")
|
||||
.and_then(|obj| obj.get("url"))
|
||||
.and_then(|u| u.as_str())
|
||||
.unwrap_or_default();
|
||||
|
||||
if url.starts_with("http://") || url.starts_with("https://") {
|
||||
return Err(ProviderError::ExecutionError(
|
||||
"Remote image URLs are not supported with local inference. \
|
||||
Please attach the image directly."
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let base64_data = url.split_once(',').map_or(url, |(_, data)| data);
|
||||
|
||||
let bytes = BASE64_STANDARD.decode(base64_data).map_err(|e| {
|
||||
ProviderError::ExecutionError(format!("Failed to decode base64 image: {e}"))
|
||||
})?;
|
||||
|
||||
images.push(ExtractedImage { bytes });
|
||||
|
||||
*part = serde_json::json!({
|
||||
"type": "text",
|
||||
"text": marker,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let messages_json = serde_json::to_string(&messages)
|
||||
.map_err(|e| ProviderError::ExecutionError(format!("Failed to serialize messages: {e}")))?;
|
||||
|
||||
Ok(MultimodalMessages {
|
||||
messages_json,
|
||||
images,
|
||||
})
|
||||
}
|
||||
|
||||
/// Scan messages for `MessageContent::Image` entries. Return the extracted image
|
||||
/// bytes and a new message list with images replaced by text marker placeholders.
|
||||
pub fn extract_images_from_messages(
|
||||
messages: &[Message],
|
||||
marker: &str,
|
||||
) -> (Vec<ExtractedImage>, Vec<Message>) {
|
||||
let mut images = Vec::new();
|
||||
let mut new_messages = Vec::with_capacity(messages.len());
|
||||
|
||||
for msg in messages {
|
||||
let mut new_content = Vec::with_capacity(msg.content.len());
|
||||
for content in &msg.content {
|
||||
match content {
|
||||
MessageContent::Image(img) => {
|
||||
if let Ok(bytes) = BASE64_STANDARD.decode(&img.data) {
|
||||
images.push(ExtractedImage { bytes });
|
||||
new_content.push(MessageContent::text(marker));
|
||||
} else {
|
||||
new_content.push(MessageContent::text(
|
||||
"[Image attached — failed to decode image data]",
|
||||
));
|
||||
}
|
||||
}
|
||||
other => new_content.push(other.clone()),
|
||||
}
|
||||
}
|
||||
new_messages.push(Message {
|
||||
role: msg.role.clone(),
|
||||
content: new_content,
|
||||
..msg.clone()
|
||||
});
|
||||
}
|
||||
|
||||
(images, new_messages)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
fn make_test_messages_json(parts: Vec<Value>) -> String {
|
||||
serde_json::to_string(&vec![json!({
|
||||
"role": "user",
|
||||
"content": parts,
|
||||
})])
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn tiny_png_base64() -> String {
|
||||
// 1x1 red PNG
|
||||
let bytes: &[u8] = &[
|
||||
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D, 0x49, 0x48,
|
||||
0x44, 0x52, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02, 0x00, 0x00,
|
||||
0x00, 0x90, 0x77, 0x53, 0xDE, 0x00, 0x00, 0x00, 0x0C, 0x49, 0x44, 0x41, 0x54, 0x08,
|
||||
0xD7, 0x63, 0xF8, 0xCF, 0xC0, 0x00, 0x00, 0x00, 0x03, 0x00, 0x01, 0x36, 0x28, 0x19,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4E, 0x44, 0xAE, 0x42, 0x60, 0x82,
|
||||
];
|
||||
BASE64_STANDARD.encode(bytes)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_images_replaces_image_url_with_marker() {
|
||||
let b64 = tiny_png_base64();
|
||||
let json = make_test_messages_json(vec![json!({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": format!("data:image/png;base64,{b64}")}
|
||||
})]);
|
||||
|
||||
let result = extract_images_from_messages_json(&json, "<__media__>").unwrap();
|
||||
assert_eq!(result.images.len(), 1);
|
||||
assert!(!result.images[0].bytes.is_empty());
|
||||
|
||||
let parsed: Vec<Value> = serde_json::from_str(&result.messages_json).unwrap();
|
||||
let content = parsed[0]["content"].as_array().unwrap();
|
||||
assert_eq!(content[0]["type"], "text");
|
||||
assert_eq!(content[0]["text"], "<__media__>");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_images_preserves_text_parts() {
|
||||
let json = make_test_messages_json(vec![json!({
|
||||
"type": "text",
|
||||
"text": "Hello world"
|
||||
})]);
|
||||
|
||||
let result = extract_images_from_messages_json(&json, "<__media__>").unwrap();
|
||||
assert!(result.images.is_empty());
|
||||
|
||||
let parsed: Vec<Value> = serde_json::from_str(&result.messages_json).unwrap();
|
||||
let content = parsed[0]["content"].as_array().unwrap();
|
||||
assert_eq!(content[0]["type"], "text");
|
||||
assert_eq!(content[0]["text"], "Hello world");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_images_multiple_images() {
|
||||
let b64 = tiny_png_base64();
|
||||
let json = make_test_messages_json(vec![
|
||||
json!({"type": "image_url", "image_url": {"url": format!("data:image/png;base64,{b64}")}}),
|
||||
json!({"type": "text", "text": "describe both"}),
|
||||
json!({"type": "image_url", "image_url": {"url": format!("data:image/png;base64,{b64}")}}),
|
||||
]);
|
||||
|
||||
let result = extract_images_from_messages_json(&json, "<__media__>").unwrap();
|
||||
assert_eq!(result.images.len(), 2);
|
||||
|
||||
let parsed: Vec<Value> = serde_json::from_str(&result.messages_json).unwrap();
|
||||
let content = parsed[0]["content"].as_array().unwrap();
|
||||
assert_eq!(content[0]["type"], "text");
|
||||
assert_eq!(content[0]["text"], "<__media__>");
|
||||
assert_eq!(content[1]["type"], "text");
|
||||
assert_eq!(content[1]["text"], "describe both");
|
||||
assert_eq!(content[2]["type"], "text");
|
||||
assert_eq!(content[2]["text"], "<__media__>");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_images_no_images() {
|
||||
let json = make_test_messages_json(vec![json!({
|
||||
"type": "text",
|
||||
"text": "just text"
|
||||
})]);
|
||||
|
||||
let result = extract_images_from_messages_json(&json, "<__media__>").unwrap();
|
||||
assert!(result.images.is_empty());
|
||||
// JSON should be equivalent
|
||||
let original: Vec<Value> = serde_json::from_str(&json).unwrap();
|
||||
let result_parsed: Vec<Value> = serde_json::from_str(&result.messages_json).unwrap();
|
||||
assert_eq!(original, result_parsed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_images_http_url_rejected() {
|
||||
let json = make_test_messages_json(vec![json!({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "https://example.com/image.png"}
|
||||
})]);
|
||||
|
||||
let result = extract_images_from_messages_json(&json, "<__media__>");
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("Remote image URLs are not supported"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_images_mixed_content() {
|
||||
let b64 = tiny_png_base64();
|
||||
// Two messages: first with text+image, second with just text
|
||||
let json = serde_json::to_string(&vec![
|
||||
json!({
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What is this?"},
|
||||
{"type": "image_url", "image_url": {"url": format!("data:image/png;base64,{b64}")}},
|
||||
]
|
||||
}),
|
||||
json!({
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "It looks like a red pixel."},
|
||||
]
|
||||
}),
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let result = extract_images_from_messages_json(&json, "<__media__>").unwrap();
|
||||
assert_eq!(result.images.len(), 1);
|
||||
|
||||
let parsed: Vec<Value> = serde_json::from_str(&result.messages_json).unwrap();
|
||||
// First message: text preserved, image replaced
|
||||
let content0 = parsed[0]["content"].as_array().unwrap();
|
||||
assert_eq!(content0[0]["text"], "What is this?");
|
||||
assert_eq!(content0[1]["text"], "<__media__>");
|
||||
// Second message unchanged
|
||||
let content1 = parsed[1]["content"].as_array().unwrap();
|
||||
assert_eq!(content1[0]["text"], "It looks like a red pixel.");
|
||||
}
|
||||
|
||||
// --- Tests for extract_images_from_messages (Message-based) ---
|
||||
|
||||
#[test]
|
||||
fn test_messages_extract_replaces_image_with_marker() {
|
||||
let b64 = tiny_png_base64();
|
||||
let messages = vec![Message::user().with_image(b64, "image/png")];
|
||||
|
||||
let (images, new_msgs) = extract_images_from_messages(&messages, "<__media__>");
|
||||
assert_eq!(images.len(), 1);
|
||||
assert!(!images[0].bytes.is_empty());
|
||||
assert_eq!(new_msgs.len(), 1);
|
||||
assert_eq!(new_msgs[0].as_concat_text(), "<__media__>");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_messages_extract_preserves_text() {
|
||||
let messages = vec![Message::user().with_text("Hello world")];
|
||||
|
||||
let (images, new_msgs) = extract_images_from_messages(&messages, "<__media__>");
|
||||
assert!(images.is_empty());
|
||||
assert_eq!(new_msgs[0].as_concat_text(), "Hello world");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_messages_extract_multiple_images() {
|
||||
let b64 = tiny_png_base64();
|
||||
let messages = vec![Message::user()
|
||||
.with_image(b64.clone(), "image/png")
|
||||
.with_text("describe both")
|
||||
.with_image(b64, "image/png")];
|
||||
|
||||
let (images, new_msgs) = extract_images_from_messages(&messages, "<__media__>");
|
||||
assert_eq!(images.len(), 2);
|
||||
assert_eq!(new_msgs[0].content.len(), 3);
|
||||
assert_eq!(
|
||||
new_msgs[0].as_concat_text(),
|
||||
"<__media__>\ndescribe both\n<__media__>"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_messages_extract_no_images() {
|
||||
let messages = vec![Message::user().with_text("just text")];
|
||||
|
||||
let (images, new_msgs) = extract_images_from_messages(&messages, "<__media__>");
|
||||
assert!(images.is_empty());
|
||||
assert_eq!(new_msgs[0].as_concat_text(), "just text");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_messages_extract_invalid_base64() {
|
||||
let messages = vec![Message::user().with_image("not-valid-base64!!!", "image/png")];
|
||||
|
||||
let (images, new_msgs) = extract_images_from_messages(&messages, "<__media__>");
|
||||
assert!(images.is_empty());
|
||||
assert!(new_msgs[0].as_concat_text().contains("failed to decode"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_messages_extract_mixed_content() {
|
||||
let b64 = tiny_png_base64();
|
||||
let messages = vec![
|
||||
Message::user()
|
||||
.with_text("What is this?")
|
||||
.with_image(b64, "image/png"),
|
||||
Message::assistant().with_text("It looks like a red pixel."),
|
||||
];
|
||||
|
||||
let (images, new_msgs) = extract_images_from_messages(&messages, "<__media__>");
|
||||
assert_eq!(images.len(), 1);
|
||||
assert_eq!(new_msgs.len(), 2);
|
||||
assert_eq!(new_msgs[0].as_concat_text(), "What is this?\n<__media__>");
|
||||
assert_eq!(new_msgs[1].as_concat_text(), "It looks like a red pixel.");
|
||||
}
|
||||
}
|
||||
|
|
@ -9,7 +9,11 @@
|
|||
//!
|
||||
//! Run with a specific model:
|
||||
//! TEST_MODEL="bartowski/Qwen_Qwen3-32B-GGUF:Q4_K_M" cargo test -p goose --test local_inference_integration -- --ignored
|
||||
//!
|
||||
//! Run vision tests (requires a vision-capable model like gemma-4):
|
||||
//! TEST_VISION_MODEL="unsloth/gemma-4-E4B-it-GGUF:Q4_K_M" cargo test -p goose --test local_inference_integration test_local_inference_vision -- --ignored
|
||||
|
||||
use base64::prelude::*;
|
||||
use futures::StreamExt;
|
||||
use goose::conversation::message::Message;
|
||||
use goose::model::ModelConfig;
|
||||
|
|
@ -93,3 +97,130 @@ async fn test_local_inference_large_prompt() {
|
|||
text.len()
|
||||
);
|
||||
}
|
||||
|
||||
fn vision_test_model() -> Option<String> {
|
||||
std::env::var("TEST_VISION_MODEL").ok()
|
||||
}
|
||||
|
||||
/// Generate a small solid-colour 2x2 red PNG as raw bytes.
|
||||
fn tiny_red_png() -> Vec<u8> {
|
||||
vec![
|
||||
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature
|
||||
0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52, // IHDR chunk
|
||||
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, // 1x1
|
||||
0x08, 0x02, 0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE, // RGB, 8-bit
|
||||
0x00, 0x00, 0x00, 0x0C, 0x49, 0x44, 0x41, 0x54, // IDAT chunk
|
||||
0x08, 0xD7, 0x63, 0xF8, 0xCF, 0xC0, 0x00, 0x00, 0x00, 0x03, 0x00, 0x01, 0x36, 0x28, 0x19,
|
||||
0x00, // compressed pixel data
|
||||
0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4E, 0x44, 0xAE, 0x42, 0x60, 0x82, // IEND
|
||||
]
|
||||
}
|
||||
|
||||
/// Test that a vision-capable local model can process a message with an embedded image
|
||||
/// and produce a text response without crashing.
|
||||
///
|
||||
/// Requires TEST_VISION_MODEL to be set to a downloaded vision model.
|
||||
/// Example:
|
||||
/// TEST_VISION_MODEL="unsloth/gemma-4-E4B-it-GGUF:Q4_K_M" \
|
||||
/// cargo test -p goose --test local_inference_integration test_local_inference_vision -- --ignored
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_local_inference_vision_produces_output() {
|
||||
let model_id = match vision_test_model() {
|
||||
Some(id) => id,
|
||||
None => {
|
||||
eprintln!(
|
||||
"Skipping vision test: TEST_VISION_MODEL not set. \
|
||||
Set it to a vision-capable model like unsloth/gemma-4-E4B-it-GGUF:Q4_K_M"
|
||||
);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let model_config = ModelConfig::new(&model_id).expect("valid model config");
|
||||
let provider = create("local", model_config.clone(), Vec::new())
|
||||
.await
|
||||
.expect("provider creation should succeed");
|
||||
|
||||
let image_bytes = tiny_red_png();
|
||||
let image_b64 = BASE64_STANDARD.encode(&image_bytes);
|
||||
|
||||
let system = "You are a helpful assistant. Describe images briefly.";
|
||||
let messages = vec![Message::user()
|
||||
.with_text("What color is this image?")
|
||||
.with_image(image_b64, "image/png")];
|
||||
|
||||
let mut stream = provider
|
||||
.stream(&model_config, "test-vision-session", system, &messages, &[])
|
||||
.await
|
||||
.expect("stream should start for vision input");
|
||||
|
||||
let mut got_text = false;
|
||||
let mut collected_text = String::new();
|
||||
|
||||
while let Some(result) = stream.next().await {
|
||||
let (msg, _usage) = result.expect("stream item should be Ok");
|
||||
if let Some(m) = msg {
|
||||
got_text = true;
|
||||
collected_text.push_str(&m.as_concat_text());
|
||||
}
|
||||
}
|
||||
|
||||
assert!(
|
||||
got_text,
|
||||
"vision stream should produce at least one text message"
|
||||
);
|
||||
assert!(
|
||||
!collected_text.is_empty(),
|
||||
"vision response should contain text"
|
||||
);
|
||||
println!("Vision response: {collected_text}");
|
||||
}
|
||||
|
||||
/// Test that sending an image to a text-only model produces a clear error
|
||||
/// rather than crashing.
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_local_inference_vision_text_only_model_graceful() {
|
||||
let model_config = ModelConfig::new(&test_model()).expect("valid model config");
|
||||
let provider = create("local", model_config.clone(), Vec::new())
|
||||
.await
|
||||
.expect("provider creation should succeed");
|
||||
|
||||
let image_bytes = tiny_red_png();
|
||||
let image_b64 = BASE64_STANDARD.encode(&image_bytes);
|
||||
|
||||
let system = "You are a helpful assistant.";
|
||||
let messages = vec![Message::user()
|
||||
.with_text("What is this?")
|
||||
.with_image(image_b64, "image/png")];
|
||||
|
||||
let mut stream = provider
|
||||
.stream(&model_config, "test-session", system, &messages, &[])
|
||||
.await
|
||||
.expect("stream should start");
|
||||
|
||||
// The stream should either produce a response with the image stripped
|
||||
// (placeholder text) or produce an error — but it must not crash.
|
||||
let mut completed = false;
|
||||
while let Some(result) = stream.next().await {
|
||||
match result {
|
||||
Ok(_) => completed = true,
|
||||
Err(e) => {
|
||||
// An error about missing vision support is acceptable
|
||||
let err_msg = e.to_string();
|
||||
assert!(
|
||||
err_msg.contains("vision") || err_msg.contains("image"),
|
||||
"error should mention vision/image support, got: {err_msg}"
|
||||
);
|
||||
completed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert!(
|
||||
completed,
|
||||
"stream should complete without crashing when images sent to text-only model"
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -317,6 +317,31 @@ prompt: |
|
|||
Log results to: {{ workspace_dir }}/phase3_delegation.md
|
||||
{% endif %}
|
||||
|
||||
{% if test_phases == "all" or "vision" in test_phases %}
|
||||
## 📷 PHASE 3B: Local Inference Vision Testing
|
||||
|
||||
**Prerequisites**: A vision-capable local model must be downloaded (e.g., gemma-4-E4B).
|
||||
Skip this phase if no local vision model is available.
|
||||
|
||||
### Vision Smoke Test
|
||||
1. Create a small test image:
|
||||
```
|
||||
python3 -c "import struct, zlib; raw=b'\x00\xff\x00\x00'; d=zlib.compress(raw); ihdr=b'\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00'; print('Created test.png')"
|
||||
```
|
||||
Or simply create a 1-pixel PNG test image using available tools.
|
||||
2. Verify the test image file exists and is valid.
|
||||
3. Send a message to the local vision model referencing the test image.
|
||||
4. Verify the model responds with text (not an error or crash).
|
||||
5. Verify the response acknowledges the image content.
|
||||
|
||||
### Vision Error Handling Test
|
||||
1. If a text-only local model is available, send it a message with an image attached.
|
||||
2. Verify it responds gracefully (either with a placeholder message or a clear error),
|
||||
not with a crash or FFI error.
|
||||
|
||||
Log results to: {{ workspace_dir }}/phase3b_vision.md
|
||||
{% endif %}
|
||||
|
||||
{% if test_phases == "all" or "advanced" in test_phases %}
|
||||
## 🔬 PHASE 4: Advanced Testing
|
||||
|
||||
|
|
|
|||
|
|
@ -5724,7 +5724,8 @@
|
|||
"size_bytes",
|
||||
"status",
|
||||
"recommended",
|
||||
"settings"
|
||||
"settings",
|
||||
"vision_capable"
|
||||
],
|
||||
"properties": {
|
||||
"filename": {
|
||||
|
|
@ -5733,6 +5734,14 @@
|
|||
"id": {
|
||||
"type": "string"
|
||||
},
|
||||
"mmproj_status": {
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/ModelDownloadStatus"
|
||||
}
|
||||
],
|
||||
"nullable": true
|
||||
},
|
||||
"quantization": {
|
||||
"type": "string"
|
||||
},
|
||||
|
|
@ -5752,6 +5761,9 @@
|
|||
},
|
||||
"status": {
|
||||
"$ref": "#/components/schemas/ModelDownloadStatus"
|
||||
},
|
||||
"vision_capable": {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
|
@ -6497,11 +6509,22 @@
|
|||
"type": "number",
|
||||
"format": "float"
|
||||
},
|
||||
"image_token_estimate": {
|
||||
"type": "integer",
|
||||
"description": "Estimated tokens per image for budget planning before mtmd tokenization.\nThe actual count is determined after tokenization via `chunks.total_tokens()`.",
|
||||
"minimum": 0
|
||||
},
|
||||
"max_output_tokens": {
|
||||
"type": "integer",
|
||||
"nullable": true,
|
||||
"minimum": 0
|
||||
},
|
||||
"mmproj_size_bytes": {
|
||||
"type": "integer",
|
||||
"format": "int64",
|
||||
"description": "Size of the mmproj file in bytes, used for memory accounting.",
|
||||
"minimum": 0
|
||||
},
|
||||
"n_batch": {
|
||||
"type": "integer",
|
||||
"format": "int32",
|
||||
|
|
@ -6542,6 +6565,10 @@
|
|||
},
|
||||
"use_mlock": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"vision_capable": {
|
||||
"type": "boolean",
|
||||
"description": "Whether this model architecture supports vision input.\nDerived from the featured model table, not user-configurable."
|
||||
}
|
||||
}
|
||||
},
|
||||
|
|
|
|||
|
|
@ -612,12 +612,14 @@ export type LoadedProvider = {
|
|||
export type LocalModelResponse = {
|
||||
filename: string;
|
||||
id: string;
|
||||
mmproj_status?: ModelDownloadStatus | null;
|
||||
quantization: string;
|
||||
recommended: boolean;
|
||||
repo_id: string;
|
||||
settings: ModelSettings;
|
||||
size_bytes: number;
|
||||
status: ModelDownloadStatus;
|
||||
vision_capable: boolean;
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -821,7 +823,16 @@ export type ModelSettings = {
|
|||
enable_thinking?: boolean;
|
||||
flash_attention?: boolean | null;
|
||||
frequency_penalty?: number;
|
||||
/**
|
||||
* Estimated tokens per image for budget planning before mtmd tokenization.
|
||||
* The actual count is determined after tokenization via `chunks.total_tokens()`.
|
||||
*/
|
||||
image_token_estimate?: number;
|
||||
max_output_tokens?: number | null;
|
||||
/**
|
||||
* Size of the mmproj file in bytes, used for memory accounting.
|
||||
*/
|
||||
mmproj_size_bytes?: number;
|
||||
n_batch?: number | null;
|
||||
n_gpu_layers?: number | null;
|
||||
n_threads?: number | null;
|
||||
|
|
@ -832,6 +843,11 @@ export type ModelSettings = {
|
|||
sampling?: SamplingConfig;
|
||||
use_jinja?: boolean;
|
||||
use_mlock?: boolean;
|
||||
/**
|
||||
* Whether this model architecture supports vision input.
|
||||
* Derived from the featured model table, not user-configurable.
|
||||
*/
|
||||
vision_capable?: boolean;
|
||||
};
|
||||
|
||||
export type ModelTemplate = {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import { useState, useEffect, useCallback, useRef } from 'react';
|
||||
import { Download, Trash2, X, ChevronDown, ChevronUp, Settings2 } from 'lucide-react';
|
||||
import { Download, Trash2, X, ChevronDown, ChevronUp, Settings2, Eye } from 'lucide-react';
|
||||
import { Button } from '../../ui/button';
|
||||
import { useModelAndProvider } from '../../ModelAndProviderContext';
|
||||
import { defineMessages, useIntl } from '../../../i18n';
|
||||
|
|
@ -83,8 +83,57 @@ const i18n = defineMessages({
|
|||
id: 'localInferenceSettings.modelSettingsTitle',
|
||||
defaultMessage: 'Model settings',
|
||||
},
|
||||
vision: {
|
||||
id: 'localInferenceSettings.vision',
|
||||
defaultMessage: 'Vision',
|
||||
},
|
||||
visionEncoderDownloading: {
|
||||
id: 'localInferenceSettings.visionEncoderDownloading',
|
||||
defaultMessage: 'Vision encoder downloading…',
|
||||
},
|
||||
visionEncoderNotDownloaded: {
|
||||
id: 'localInferenceSettings.visionEncoderNotDownloaded',
|
||||
defaultMessage: 'Vision encoder not downloaded',
|
||||
},
|
||||
});
|
||||
|
||||
const VisionBadge = ({ model, intl }: { model: LocalModelResponse; intl: ReturnType<typeof useIntl> }) => {
|
||||
if (!model.vision_capable) return null;
|
||||
|
||||
const mmproj = model.mmproj_status;
|
||||
const isDownloaded = mmproj?.state === 'Downloaded';
|
||||
const isDownloading = mmproj?.state === 'Downloading';
|
||||
|
||||
if (isDownloaded) {
|
||||
return (
|
||||
<span className="inline-flex items-center gap-1 text-xs text-green-400 bg-green-500/10 px-2 py-0.5 rounded">
|
||||
<Eye className="w-3 h-3" />
|
||||
{intl.formatMessage(i18n.vision)}
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
if (isDownloading) {
|
||||
const percent = mmproj && 'progress_percent' in mmproj
|
||||
? Math.round(mmproj.progress_percent)
|
||||
: null;
|
||||
return (
|
||||
<span className="inline-flex items-center gap-1 text-xs text-yellow-400 bg-yellow-500/10 px-2 py-0.5 rounded">
|
||||
<Eye className="w-3 h-3" />
|
||||
{intl.formatMessage(i18n.visionEncoderDownloading)}
|
||||
{percent != null && ` ${percent}%`}
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<span className="inline-flex items-center gap-1 text-xs text-text-muted bg-background-subtle px-2 py-0.5 rounded">
|
||||
<Eye className="w-3 h-3" />
|
||||
{intl.formatMessage(i18n.vision)}
|
||||
</span>
|
||||
);
|
||||
};
|
||||
|
||||
const formatBytes = (bytes: number): string => {
|
||||
if (bytes < 1024) return `${bytes}B`;
|
||||
if (bytes < 1024 * 1024) return `${(bytes / 1024).toFixed(0)}KB`;
|
||||
|
|
@ -128,6 +177,19 @@ export const LocalInferenceSettings = () => {
|
|||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, []);
|
||||
|
||||
// Poll model list while any vision encoder is downloading
|
||||
useEffect(() => {
|
||||
const hasDownloadingMmproj = models.some(
|
||||
(m) => m.vision_capable && m.mmproj_status?.state === 'Downloading'
|
||||
);
|
||||
if (!hasDownloadingMmproj) return;
|
||||
|
||||
const interval = setInterval(() => {
|
||||
loadModels();
|
||||
}, 2000);
|
||||
return () => clearInterval(interval);
|
||||
}, [models, loadModels]);
|
||||
|
||||
const selectModel = async (modelId: string) => {
|
||||
try {
|
||||
await setConfigProvider({
|
||||
|
|
@ -365,6 +427,7 @@ export const LocalInferenceSettings = () => {
|
|||
{intl.formatMessage(i18n.recommended)}
|
||||
</span>
|
||||
)}
|
||||
<VisionBadge model={model} intl={intl} />
|
||||
</div>
|
||||
<div className="flex items-center gap-1">
|
||||
<Button
|
||||
|
|
@ -414,6 +477,7 @@ export const LocalInferenceSettings = () => {
|
|||
{intl.formatMessage(i18n.recommended)}
|
||||
</span>
|
||||
)}
|
||||
<VisionBadge model={model} intl={intl} />
|
||||
</div>
|
||||
</div>
|
||||
<Button
|
||||
|
|
|
|||
|
|
@ -1838,6 +1838,15 @@
|
|||
"localInferenceSettings.title": {
|
||||
"defaultMessage": "Local Inference Models"
|
||||
},
|
||||
"localInferenceSettings.vision": {
|
||||
"defaultMessage": "Vision"
|
||||
},
|
||||
"localInferenceSettings.visionEncoderDownloading": {
|
||||
"defaultMessage": "Vision encoder downloading\u2026"
|
||||
},
|
||||
"localInferenceSettings.visionEncoderNotDownloaded": {
|
||||
"defaultMessage": "Vision encoder not downloaded"
|
||||
},
|
||||
"localModelManager.active": {
|
||||
"defaultMessage": "Active"
|
||||
},
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue