mirror of https://github.com/unslothai/unsloth.git
synced 2026-05-17 03:56:07 +00:00
Merge branch 'main' into ci/workflow-permissions-and-fewer-skips
This commit is contained in: commit 00f3e325ce
4 changed files with 227 additions and 28 deletions
@@ -78,7 +78,7 @@ class LoggingMiddleware(BaseHTTPMiddleware):
 
 
 def filter_sensitive_data(logger, method_name, event_dict):
-    """Structlog processor to filter out base64 data from logs."""
+    """Structlog processor to redact native path leases from logs."""
 
     def filter_value(value):
         if isinstance(value, str):
@@ -87,13 +87,7 @@ def filter_sensitive_data(logger, method_name, event_dict):
             except Exception:
                 pass
-            if (
-                isinstance(value, str)
-                and len(value) > 100
-                and ("," in value or "/" in value)
-            ):
-                # Likely base64 data, truncate it
-                return value[:20] + "..."
+            value = _NATIVE_PATH_LEASE_RE.sub(r"\1<redacted native path lease>", value)
             return value
         elif isinstance(value, dict):
             return {
                 k: "<redacted native path lease>"
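Note: the two hunks above edit filter_sensitive_data in the backend logging
handlers module (studio/backend/loggers/handlers.py, inferred from the new
test's import; the diff page elides the file name). As a minimal sketch of how
such a structlog processor is consumed, assuming only the standard structlog
API and the test's import path:

    # Illustrative wiring only, not the repo's actual configuration. A structlog
    # processor receives (logger, method_name, event_dict) and returns the
    # (possibly mutated) event_dict before the renderer sees it.
    import structlog
    from loggers.handlers import filter_sensitive_data  # import path as used by the new test

    structlog.configure(
        processors = [
            filter_sensitive_data,
            structlog.processors.TimeStamper(fmt = "iso"),
            structlog.processors.JSONRenderer(),
        ],
    )
    structlog.get_logger().info("load", nativePathLease = "AAAAAA.BBBBBB")
    # -> {... "nativePathLease": "<redacted native path lease>" ...}
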
@@ -48,13 +48,31 @@ class ScrapeConfig:
     max_comments_per_item: int
 
 
-def _resolve_token(token: str) -> str:
-    tok = token or os.environ.get("GH_TOKEN", "") or os.environ.get("GITHUB_TOKEN", "")
-    if not tok:
-        raise ValueError(
-            "GitHub token is required. Set it in the recipe config or the GH_TOKEN / GITHUB_TOKEN env var."
-        )
-    return tok
+@dataclass(frozen = True)
+class ResolvedToken:
+    value: str
+    source: str
+
+
+def _resolve_token(token: str) -> ResolvedToken:
+    if token:
+        return ResolvedToken(
+            value = token,
+            source = "explicit token argument (recipe-level field)",
+        )
+    if os.environ.get("GH_TOKEN"):
+        return ResolvedToken(
+            value = os.environ["GH_TOKEN"],
+            source = "GH_TOKEN environment variable",
+        )
+    if os.environ.get("GITHUB_TOKEN"):
+        return ResolvedToken(
+            value = os.environ["GITHUB_TOKEN"],
+            source = "GITHUB_TOKEN environment variable",
+        )
+    raise ValueError(
+        "GitHub token is required. Set it in the recipe config or the GH_TOKEN / GITHUB_TOKEN env var."
+    )
 
 
 def _read_jsonl(path: Path, max_rows: int | None = None):
@@ -155,7 +173,7 @@ def _flatten_commit_row(r: dict, repo: str) -> dict:
 def scrape(cfg: ScrapeConfig, base_dir: Path):
     token = _resolve_token(cfg.token)
     GitHubClient, RepoScraper = _load_impl()
-    client = GitHubClient(token = token)
+    client = GitHubClient(token = token.value, token_source = token.source)
     base_dir.mkdir(parents = True, exist_ok = True)
 
     # Per-resource trial limits. limit <= 0 means "all": use a very large cap.
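Note: a short usage sketch of the new precedence in the scrape recipe's
_resolve_token (explicit recipe field first, then GH_TOKEN, then GITHUB_TOKEN);
the token strings below are placeholders, not real credentials:

    import os

    os.environ["GH_TOKEN"] = "ghp_from_gh_token"          # placeholder value
    os.environ["GITHUB_TOKEN"] = "ghp_from_github_token"  # placeholder value

    tok = _resolve_token("ghp_explicit")
    assert tok.source == "explicit token argument (recipe-level field)"

    tok = _resolve_token("")  # empty recipe field falls through to the environment
    assert tok.value == "ghp_from_gh_token"
    assert tok.source == "GH_TOKEN environment variable"
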
@@ -9,6 +9,8 @@ import json
 import os
 import time
 import logging
+from datetime import timezone
+from email.utils import parsedate_to_datetime
 from typing import Any, Dict, Iterable, Iterator, List, Optional
 
 import requests
@@ -29,16 +31,46 @@ class RateLimitError(Exception):
     pass
 
 
+class GitHubAuthError(RuntimeError):
+    """Raised when GitHub returns 401/403 due to invalid or insufficient credentials."""
+
+
+def _retry_after_seconds(value: str | None) -> int | None:
+    if not value:
+        return None
+    try:
+        return max(0, int(value))
+    except ValueError:
+        pass
+    try:
+        retry_at = parsedate_to_datetime(value)
+    except (TypeError, ValueError, IndexError, OverflowError):
+        return None
+    if retry_at.tzinfo is None:
+        retry_at = retry_at.replace(tzinfo = timezone.utc)
+    return max(0, int(retry_at.timestamp() - time.time()))
+
+
 class GitHubClient:
     def __init__(
         self,
         min_remaining_graphql: int = 100,
         min_remaining_rest: int = 100,
+        token: str | None = None,
+        token_source: str | None = None,
     ):
-        token = token or os.environ.get("GH_TOKEN") or os.environ.get("GITHUB_TOKEN")
-        if not token:
-            raise RuntimeError("GH_TOKEN not set in environment")
+        if token:
+            self._token_source = (
+                token_source or "explicit token argument (recipe-level field)"
+            )
+        elif os.environ.get("GH_TOKEN"):
+            self._token_source = "GH_TOKEN environment variable"
+            token = os.environ["GH_TOKEN"]
+        elif os.environ.get("GITHUB_TOKEN"):
+            self._token_source = "GITHUB_TOKEN environment variable"
+            token = os.environ["GITHUB_TOKEN"]
+        else:
+            raise RuntimeError("GH_TOKEN or GITHUB_TOKEN not set in environment")
         self.session = requests.Session()
         self.session.headers.update(
             {**BASE_HEADERS, "Authorization": f"Bearer {token}"}
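Note: Retry-After may arrive either as delta-seconds or as an HTTP-date
(RFC 9110, Section 10.2.3); _retry_after_seconds above accepts both forms and
floors negative results at zero. A few illustrative calls (the date-form result
depends on the current clock):

    _retry_after_seconds("30")                             # -> 30 (delta-seconds)
    _retry_after_seconds("Wed, 21 Oct 2026 07:28:00 GMT")  # -> seconds until that date, min 0
    _retry_after_seconds("not-a-date")                     # -> None (caller skips the sleep)
    _retry_after_seconds(None)                             # -> None
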
@@ -59,6 +91,49 @@ class GitHubClient:
         log.warning("Rate limit hit. Sleeping %ds until reset.", wait)
         time.sleep(wait)
 
+    def _is_rate_limit_response(self, r: "requests.Response") -> bool:
+        if r.headers.get("Retry-After"):
+            return True
+        if r.headers.get("X-RateLimit-Remaining") == "0":
+            return True
+        body = (r.text or "").lower()
+        return any(
+            marker in body
+            for marker in (
+                "api rate limit exceeded",
+                "rate limit exceeded",
+                "secondary rate limit",
+                "secondary limit",
+                "abuse detection mechanism",
+                "abuse detection",
+            )
+        )
+
+    def _is_auth_failure(self, r: "requests.Response") -> bool:
+        """Distinguish auth failures from rate limiting on 401/403 responses.
+
+        - 401: always an auth failure (invalid / expired / wrong-scope token).
+        - 403: an auth failure UNLESS the response carries a clear rate-limit signal
+          (Retry-After header, X-RateLimit-Remaining: 0, or GitHub's secondary /
+          abuse rate-limit response text).
+        """
+        if r.status_code == 401:
+            return True
+        if r.status_code == 403:
+            return not self._is_rate_limit_response(r)
+        return False
+
+    def _raise_auth_error(self, r: "requests.Response", endpoint: str) -> None:
+        snippet = (r.text or "").strip()[:200]
+        request_id = r.headers.get("X-GitHub-Request-Id")
+        request_id_message = f" Request ID: {request_id}." if request_id else ""
+        raise GitHubAuthError(
+            f"GitHub {endpoint} returned {r.status_code} {r.reason}. "
+            f"Token source: {self._token_source}. "
+            f"The token is invalid, expired, or missing required scopes — "
+            f"retrying will not recover.{request_id_message} Response: {snippet}"
+        )
+
     def _check_rate_and_wait(self, kind: str) -> None:
         if kind == "graphql":
             remaining = self.graphql_remaining
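Note: the 403 split is the subtle part, since GitHub uses 403 both for bad
credentials and for secondary rate limits. A hand-built requests.Response makes
it concrete (sketch only; `client` stands for any constructed GitHubClient, and
the classifying calls are shown as comments rather than executed):

    import requests

    limited = requests.models.Response()
    limited.status_code = 403
    limited.headers["Retry-After"] = "5"         # clear rate-limit signal
    # client._is_auth_failure(limited) -> False  (retried as a rate limit)

    bad_creds = requests.models.Response()
    bad_creds.status_code = 403                  # no rate-limit signal anywhere
    # client._is_auth_failure(bad_creds) -> True (raises GitHubAuthError, no retry)
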
@@ -112,13 +187,14 @@ class GitHubClient:
                 time.sleep(backoff)
                 backoff = min(backoff * 2, 60)
                 continue
+            if self._is_auth_failure(r):
+                self._raise_auth_error(r, "GraphQL")
             if r.status_code == 403 or r.status_code == 429:
                 # Check for secondary/abuse
-                retry_after = r.headers.get("Retry-After")
-                if retry_after:
-                    t = int(retry_after)
-                    log.warning("Secondary rate limit. Sleep %ds.", t)
-                    time.sleep(t + 2)
+                retry_after = _retry_after_seconds(r.headers.get("Retry-After"))
+                if retry_after is not None:
+                    log.warning("Secondary rate limit. Sleep %ds.", retry_after)
+                    time.sleep(retry_after + 2)
                     continue
                 if self.graphql_reset:
                     self._sleep_until(self.graphql_reset)
@@ -188,12 +264,15 @@ class GitHubClient:
                 time.sleep(backoff)
                 backoff = min(backoff * 2, 60)
                 continue
+            if self._is_auth_failure(r):
+                self._raise_auth_error(r, "REST")
             if r.status_code in (403, 429):
-                retry_after = r.headers.get("Retry-After")
-                if retry_after:
-                    t = int(retry_after)
-                    log.warning("Secondary rate limit on REST. Sleep %ds.", t)
-                    time.sleep(t + 2)
+                retry_after = _retry_after_seconds(r.headers.get("Retry-After"))
+                if retry_after is not None:
+                    log.warning(
+                        "Secondary rate limit on REST. Sleep %ds.", retry_after
+                    )
+                    time.sleep(retry_after + 2)
                     continue
                 # Check if primary rate
                 if self.rest_remaining == 0 and self.rest_reset:
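Note: with these hunks, both request loops classify a response in a fixed
order: transient errors get capped exponential backoff, hard auth failures
abort immediately (retrying cannot recover), 403/429 with a usable Retry-After
sleeps and retries, and only then does the primary-limit reset sleep apply. The
backoff cap is easy to see in isolation (assuming an initial backoff of 1s; the
starting value sits outside the hunks shown):

    # Progression implied by `backoff = min(backoff * 2, 60)`:
    backoff, delays = 1, []
    for _ in range(8):
        delays.append(backoff)
        backoff = min(backoff * 2, 60)
    print(delays)  # [1, 2, 4, 8, 16, 32, 60, 60]
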
studio/backend/tests/test_log_filter_no_truncation.py (new file, 108 lines)
@@ -0,0 +1,108 @@
+# SPDX-License-Identifier: AGPL-3.0-only
+# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
+
+"""
+Regression tests for studio.backend.loggers.handlers.filter_sensitive_data.
+
+Context: filter_sensitive_data was originally written with a base64-detection
+heuristic that truncated any string >100 chars containing ',' or '/' down to
+20 chars + '...'. The block was dormant until PR #5246 wired the processor
+into the structlog chain to redact native-path leases. Once active, the
+heuristic ate normal log lines emitted by llama_cpp_backend (GGUF size
+summary, mmproj selection, the full llama-server command line) and any
+exception traceback that happened to contain a file path.
+
+These tests pin two properties:
+
+1. Long, comma- or slash-bearing log messages flow through filter_sensitive_data
+   unchanged. The exact strings exercised match the call sites at
+   studio/backend/core/inference/llama_cpp.py:2117, :2283, and :2312 that
+   were truncated in the original bug report.
+
+2. PR #5246's native-path lease redaction still fires for both the inline
+   ``native_path_lease=...`` regex form and the ``nativePathLease`` dict-key
+   form. This guards against future regressions that strip redaction along
+   with the truncation block.
+"""
+
+from loggers.handlers import filter_sensitive_data
+
+
+def _run(event_dict):
+    return filter_sensitive_data(logger = None, method_name = "info", event_dict = event_dict)
+
+
+class TestNoTruncation:
+    def test_gguf_size_summary_survives(self):
+        # Mirrors the f-string at studio/backend/core/inference/llama_cpp.py:2117
+        event = (
+            "GGUF size: 232.9 GB, est. KV cache: 87.0 GB, context: 259072, "
+            "GPUs free: [(0, 80000), (1, 80000)], selected: [0, 1], fit: False"
+        )
+        out = _run({"event": event})
+        assert out["event"] == event
+        assert "..." not in out["event"]
+
+    def test_mmproj_path_survives(self):
+        # Mirrors logger.info at studio/backend/core/inference/llama_cpp.py:2283
+        event = (
+            "Using mmproj for vision: "
+            "/home/user/.cache/unsloth/models/some-vision-model-uncensored-r1-distill/mmproj-F16.gguf"
+        )
+        out = _run({"event": event})
+        assert out["event"] == event
+
+    def test_llama_server_command_survives(self):
+        # Mirrors logger.info at studio/backend/core/inference/llama_cpp.py:2312
+        event = (
+            "Starting llama-server: /home/user/.unsloth/studio/llama.cpp/build/bin/llama-server "
+            "-m /home/user/.cache/unsloth/models/foo.gguf --port 8090 -c 259072 --parallel 1 "
+            "--flash-attn on --mmproj /home/user/.cache/unsloth/models/mmproj-F16.gguf"
+        )
+        out = _run({"event": event})
+        assert out["event"] == event
+
+    def test_traceback_with_paths_survives(self):
+        traceback_str = (
+            "Traceback (most recent call last):\n"
+            '  File "/home/user/.unsloth/studio/unsloth_studio/lib/python3.11/site-packages/'
+            'studio/backend/core/inference/llama_cpp.py", line 2312, in start\n'
+            '    raise RuntimeError("llama-server crashed: bad alloc, /dev/shm full")\n'
+            "RuntimeError: llama-server crashed: bad alloc, /dev/shm full"
+        )
+        out = _run({"event": "llama-server crashed", "exception": traceback_str})
+        assert out["exception"] == traceback_str
+        assert "..." not in out["exception"]
+
+    def test_nested_long_string_in_dict_survives(self):
+        long_value = (
+            "/very/long/path/with,many,commas,and/slashes/that/used/to/get/"
+            "chopped/to/twenty/chars/file.gguf"
+        )
+        out = _run({"event": "load", "details": {"path": long_value}})
+        assert out["details"]["path"] == long_value
+
+
+class TestNativePathLeaseRedactionStillWorks:
+    """Guards PR #5246's redaction from being lost alongside the truncation block."""
+
+    def test_inline_native_path_lease_value_redacted(self):
+        event = (
+            "rejected request: native_path_lease=AAAAAA.BBBBBB extra context "
+            "with /some/path,values"
+        )
+        out = _run({"event": event})
+        assert "AAAAAA.BBBBBB" not in out["event"]
+        assert "<redacted native path lease>" in out["event"]
+
+    def test_camelcase_native_path_lease_dict_key_redacted(self):
+        out = _run({"event": "load", "nativePathLease": "AAAAAA.BBBBBB"})
+        assert out["nativePathLease"] == "<redacted native path lease>"
+
+    def test_snakecase_native_path_lease_dict_key_redacted(self):
+        out = _run({"event": "load", "native_path_lease": "AAAAAA.BBBBBB"})
+        assert out["native_path_lease"] == "<redacted native path lease>"
+
+    def test_nested_native_path_lease_key_redacted(self):
+        out = _run({"event": "load", "payload": {"nativePathLease": "AAAAAA.BBBBBB"}})
+        assert out["payload"]["nativePathLease"] == "<redacted native path lease>"
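Note: the new tests depend only on the `loggers.handlers` import, so they
should run standalone, e.g. `pytest studio/backend/tests/test_log_filter_no_truncation.py -q`
(assuming studio/backend is on sys.path, as the bare `loggers.handlers` import
suggests).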