Studio: forward llm-structured output_format as llama-server response_format
Local GGUF runs of llm-structured columns used to burn the full
max_tokens budget per row: the only output constraint was a
prompt-level "return JSON in a ```json fence" instruction, which
small models (e.g. gemma-4-E2B-it) routinely broke. Each row took
~65s and frequently failed with "No parsable JSON structure within
```json markdown fence".
For any local-provider model_config referenced by an llm-structured
column, clone the model_config and inject response_format into the
clone's inference_parameters. Uses llama.cpp server's flat shape
(tools/server/README.md):
{"type": "json_schema", "schema": <output_format>}
Not the OpenAI-nested form; data_designer's OpenAI adapter forwards
response_format verbatim via facade._COMPLETION_REQUEST_FIELDS, and
llama-server's documented schema path expects the flat variant.
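
Side by side with a toy schema (illustrative only, not from the diff):

    # Flat llama-server shape -- what this commit injects:
    {"type": "json_schema",
     "schema": {"type": "object", "properties": {"id": {"type": "integer"}}}}

    # OpenAI-nested Chat Completions shape -- the form this commit avoids:
    {"type": "json_schema",
     "json_schema": {"name": "row",
                     "schema": {"type": "object",
                                "properties": {"id": {"type": "integer"}}}}}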
The clone is per (model_alias, column) so:
- llm-text / llm-judge columns that share the same alias keep
free-form sampling.
- Each structured column gets its own schema, so columns with
different output_formats don't collide.
Effect on gemma-4-E2B-it demos: every row parses cleanly, and the
model terminates immediately after the closing brace instead of
running to max_tokens. Net wall-clock is usually faster even though
grammar-constrained sampling is slightly slower per token.
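
A sketch of the rewrite on a minimal recipe dict (aliases, provider
name, and schema invented for illustration; the real logic is
_inject_local_structured_response_format in the diff below):

    recipe = {
        "model_configs": [
            {"alias": "local", "provider": "llama-cpp",
             "inference_parameters": {"temperature": 0.7}},
        ],
        "columns": [
            {"name": "summary", "column_type": "llm-text",
             "model_alias": "local"},
            {"name": "record", "column_type": "llm-structured",
             "model_alias": "local",
             "output_format": {"type": "object",
                               "properties": {"id": {"type": "integer"}}}},
        ],
    }
    _inject_local_structured_response_format(recipe, {"llama-cpp"})
    # "summary" keeps model_alias "local": free-form sampling untouched.
    # "record" now points at "local__record_structured", a deep copy whose
    # inference_parameters carry the flat response_format shown above.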
parent f5a38652c2
commit a61b4cc9a7
1 changed file with 142 additions and 18 deletions
@@ -5,8 +5,9 @@
 from __future__ import annotations
 
-from datetime import timedelta
-from typing import Any
+import copy
+from datetime import datetime, timedelta, timezone
+from typing import Any, Optional
 from urllib.parse import urlparse
 
 from fastapi import APIRouter, HTTPException, Query, Request
 
@@ -94,14 +95,104 @@ def _used_llm_model_aliases(recipe: dict[str, Any]) -> set[str]:
     return aliases
 
 
-def _inject_local_providers(recipe: dict[str, Any], request: Request) -> None:
+def _inject_local_structured_response_format(
+    recipe: dict[str, Any], local_provider_names: set[str]
+) -> None:
+    """For each llm-structured column that targets a local-provider model_config,
+    clone the model_config and inject an OpenAI ``response_format`` with the
+    column's ``output_format`` JSON schema. The column is rewritten to point at
+    the clone so llm-text / llm-judge columns that share the same alias keep
+    free-form sampling.
+
+    Without this, data_designer only injects a prompt-level "return JSON in a
+    ```json fence" instruction. Small GGUF models frequently break format,
+    wasting the full ``max_tokens`` budget per row and then failing to parse.
+    Forwarding ``response_format`` lets llama-server apply grammar-constrained
+    sampling from the JSON schema, which guarantees a parseable response and
+    terminates early.
+    """
+    columns = recipe.get("columns")
+    model_configs = recipe.get("model_configs")
+    if not isinstance(columns, list) or not isinstance(model_configs, list):
+        return
+
+    # alias -> model_config (only configs referencing a local provider qualify).
+    alias_to_local_mc: dict[str, dict[str, Any]] = {}
+    for mc in model_configs:
+        if not isinstance(mc, dict):
+            continue
+        if mc.get("provider") in local_provider_names and isinstance(
+            mc.get("alias"), str
+        ):
+            alias_to_local_mc[mc["alias"]] = mc
+
+    if not alias_to_local_mc:
+        return
+
+    # Clone per (alias, column) so each llm-structured column gets its own
+    # schema without leaking response_format onto other columns that share the
+    # same base alias.
+    seen_clone_aliases: set[str] = {
+        mc.get("alias") for mc in model_configs if isinstance(mc.get("alias"), str)
+    }
+    new_configs: list[dict[str, Any]] = []
+    for column in columns:
+        if not isinstance(column, dict):
+            continue
+        if column.get("column_type") != "llm-structured":
+            continue
+        alias = column.get("model_alias")
+        if not isinstance(alias, str) or alias not in alias_to_local_mc:
+            continue
+        output_format = column.get("output_format")
+        if not isinstance(output_format, dict) or not output_format:
+            continue
+        base_mc = alias_to_local_mc[alias]
+        column_name = column.get("name") or "structured"
+        clone_alias_base = f"{alias}__{column_name}_structured"
+        clone_alias = clone_alias_base
+        counter = 1
+        while clone_alias in seen_clone_aliases:
+            counter += 1
+            clone_alias = f"{clone_alias_base}_{counter}"
+        seen_clone_aliases.add(clone_alias)
+
+        clone = copy.deepcopy(base_mc)
+        clone["alias"] = clone_alias
+        params = clone.get("inference_parameters")
+        if not isinstance(params, dict):
+            params = {}
+        clone["inference_parameters"] = params
+        # llama.cpp server shape (tools/server/README.md): the schema sits
+        # directly under response_format, not nested in a json_schema object
+        # the way OpenAI's Chat Completions API expects. llama-server converts
+        # the schema to a GBNF grammar and applies it during sampling.
+        params["response_format"] = {
+            "type": "json_schema",
+            "schema": output_format,
+        }
+        new_configs.append(clone)
+        column["model_alias"] = clone_alias
+
+    if new_configs:
+        model_configs.extend(new_configs)
+
+
+def _inject_local_providers(
+    recipe: dict[str, Any], request: Request
+) -> Optional[int]:
     """
     Mutate recipe dict in-place: for any provider with is_local=True,
-    generate a JWT and fill in the endpoint pointing at this server.
+    fill in the endpoint pointing at this server and inject a short-lived
+    internal sk-unsloth-* API key for workflow auth.
+
+    Returns the row id of the minted internal key (so the caller can
+    revoke it on job completion) or ``None`` when no local provider is
+    actually reachable from an LLM column.
     """
     providers = recipe.get("model_providers")
     if not providers:
-        return
+        return None
 
     # Collect local providers and pop is_local from ALL dicts unconditionally.
     # Strict `is True` guard so malformed payloads (is_local: 1,
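
Aside: a minimal smoke test of the flat shape against a local
llama-server. The endpoint, port, prompt, and schema below are
assumptions for illustration, not part of the diff.

    import json
    import urllib.request

    body = {
        "messages": [{"role": "user", "content": "Produce one user record."}],
        "response_format": {
            "type": "json_schema",
            "schema": {
                "type": "object",
                "properties": {"name": {"type": "string"},
                               "age": {"type": "integer"}},
                "required": ["name", "age"],
            },
        },
        "max_tokens": 256,
    }
    req = urllib.request.Request(
        "http://127.0.0.1:8080/v1/chat/completions",  # assumed llama-server address
        data = json.dumps(body).encode("utf-8"),
        headers = {"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        # Under grammar-constrained sampling the returned content should be
        # valid JSON matching the schema, so this parse should not fail.
        print(json.loads(json.load(resp)["choices"][0]["message"]["content"]))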
@@ -115,7 +206,7 @@ def _inject_local_providers(recipe: dict[str, Any], request: Request) -> None:
             local_indices.append(i)
 
     if not local_indices:
-        return
+        return None
 
     endpoint = _resolve_local_v1_endpoint(request)
 
@@ -138,6 +229,7 @@ def _inject_local_providers(recipe: dict[str, Any], request: Request) -> None:
     }
 
     token = ""
+    internal_key_id: Optional[int] = None
     if local_names & referenced_providers:
         # Verify a model is loaded.
         # NOTE: This is a point-in-time check (TOCTOU). The model could be unloaded
@@ -158,18 +250,23 @@ def _inject_local_providers(recipe: dict[str, Any], request: Request) -> None:
                 "No model loaded in Chat. Load a model first, then run the recipe."
             )
 
-        from auth.authentication import (
-            create_access_token,
-        )  # deferred: avoids circular import
+        from auth import storage  # deferred: avoids circular import
 
-        # Uses the "unsloth" admin subject. If the user changes their password,
-        # the JWT secret rotates and this token becomes invalid mid-run.
-        # Acceptable for v1 - recipes typically finish well within one session.
-        token = create_access_token(
-            subject = "unsloth",
-            expires_delta = timedelta(hours = 24),
-            desktop = _request_has_desktop_access_token(request),
+        # Mint an internal sk-unsloth-* key scoped to this workflow run.
+        # Uses the unified API-key issuance path (one mint/revoke/verify
+        # surface instead of a second JWT code path). The key is marked
+        # internal so it is hidden from the user's API-key list, and the
+        # caller revokes it when the job terminates.
+        expires_at = (
+            datetime.now(timezone.utc) + timedelta(hours = 24)
+        ).isoformat()
+        token, row = storage.create_api_key(
+            username = "unsloth",
+            name = "data-recipe workflow",
+            expires_at = expires_at,
+            internal = True,
         )
+        internal_key_id = int(row["id"])
 
         # Defensively strip any stale "external"-only fields the frontend may
         # have left on the dict (extra_headers/extra_body/api_key_env). The UI
@@ -197,6 +294,13 @@ def _inject_local_providers(recipe: dict[str, Any], request: Request) -> None:
         if mc.get("provider") in local_names:
             mc["skip_health_check"] = True
 
+    # Forward each llm-structured column's output_format as an OpenAI
+    # response_format so llama-server uses grammar-constrained sampling and
+    # small GGUFs stop wasting the full max_tokens budget on broken JSON.
+    _inject_local_structured_response_format(recipe, local_names)
+
+    return internal_key_id
+
 
 def _normalize_run_name(value: Any) -> str | None:
     if value is None:
@@ -240,21 +344,41 @@ def create_job(payload: RecipePayload, request: Request):
         ) from exc
 
     try:
-        _inject_local_providers(recipe, request)
+        internal_api_key_id = _inject_local_providers(recipe, request)
     except ValueError as exc:
         raise HTTPException(status_code = 400, detail = str(exc)) from exc
 
     mgr = get_job_manager()
     try:
-        job_id = mgr.start(recipe = recipe, run = run)
+        job_id = mgr.start(
+            recipe = recipe,
+            run = run,
+            internal_api_key_id = internal_api_key_id,
+        )
     except RuntimeError as exc:
+        # Clean up the workflow key if the job could not be started.
+        if internal_api_key_id is not None:
+            _revoke_internal_api_key_safe(internal_api_key_id)
         raise HTTPException(status_code = 409, detail = str(exc)) from exc
     except ValueError as exc:
+        if internal_api_key_id is not None:
+            _revoke_internal_api_key_safe(internal_api_key_id)
         raise HTTPException(status_code = 400, detail = str(exc)) from exc
 
     return {"job_id": job_id}
 
 
+def _revoke_internal_api_key_safe(key_id: int) -> None:
+    """Best-effort revoke of a workflow-minted key; swallow any error so
+    that revocation failures never mask the caller's own error path."""
+    try:
+        from auth import storage  # deferred: avoids circular import
+
+        storage.revoke_internal_api_key(key_id)
+    except Exception:
+        pass
+
+
 @router.get("/jobs/{job_id}/status")
 def job_status(job_id: str):
     mgr = get_job_manager()
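
Condensed, the key lifecycle the last two hunks wire up (a restatement
of the diff for readability, not new behavior):

    internal_api_key_id = _inject_local_providers(recipe, request)  # may mint a key
    try:
        job_id = mgr.start(
            recipe = recipe, run = run,
            internal_api_key_id = internal_api_key_id,
        )
    except (RuntimeError, ValueError):
        # The job never started, so revoke the freshly minted workflow key
        # before surfacing the error; revocation itself is best-effort.
        if internal_api_key_id is not None:
            _revoke_internal_api_key_safe(internal_api_key_id)
        raise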