diff --git a/helpers/self_update.py b/helpers/self_update.py index 89b1386d6..326bbdaa3 100644 --- a/helpers/self_update.py +++ b/helpers/self_update.py @@ -3,6 +3,8 @@ from __future__ import annotations import os import re import subprocess +import tempfile +import time from datetime import UTC, datetime from pathlib import Path from typing import Any, Literal, TypedDict @@ -20,11 +22,14 @@ BRANCH_OPTIONS = [ SUPPORTED_BRANCHES = {option["value"] for option in BRANCH_OPTIONS} BACKUP_CONFLICT_POLICIES = {"rename", "overwrite", "fail"} MIN_SELECTOR_VERSION = (1, 0) +REMOTE_BRANCH_TAG_CACHE_TTL_SECONDS = 60.0 UPDATE_FILE_PATH = Path("/exe/a0-self-update.yaml") STATUS_FILE_PATH = Path("/exe/a0-self-update-status.yaml") LOG_FILE_PATH = Path("/exe/a0-self-update.log") +_remote_branch_tag_cache: dict[str, tuple[float, set[str]]] = {} + class PendingUpdateConfig(TypedDict): branch: Literal["main", "testing", "development"] @@ -113,6 +118,10 @@ def get_repo_dir(repo_dir: str | Path | None = None) -> Path: return Path(__file__).resolve().parents[1] +def _get_official_remote_url() -> str: + return f"https://github.com/{OFFICIAL_REPO_AUTHOR}/{OFFICIAL_REPO_NAME}.git" + + def _run_git(repo_dir: str | Path, *args: str) -> str: completed = subprocess.run( ["git", "-C", str(get_repo_dir(repo_dir)), *args], @@ -192,7 +201,37 @@ def _get_branch_reference_names(branch: str) -> list[str]: return [f"origin/{normalized_branch}", normalized_branch] -def _get_branch_merged_tags( +def _get_remote_branch_merged_tags(branch: str) -> set[str]: + normalized_branch = branch.strip().lower() + if normalized_branch not in SUPPORTED_BRANCHES: + return set() + + cached = _remote_branch_tag_cache.get(normalized_branch) + now = time.monotonic() + if cached and now - cached[0] <= REMOTE_BRANCH_TAG_CACHE_TTL_SECONDS: + return set(cached[1]) + + with tempfile.TemporaryDirectory(prefix="a0-self-update-tags-") as temp_dir: + repository = Path(temp_dir) + _run_git(repository, "init", "--bare") + _run_git( + repository, + "fetch", + "--quiet", + "--prune", + "--filter=blob:none", + "--tags", + _get_official_remote_url(), + f"refs/heads/{normalized_branch}:refs/remotes/origin/{normalized_branch}", + ) + output = _run_git(repository, "tag", "--merged", f"refs/remotes/origin/{normalized_branch}") + merged_tags = {line.strip() for line in output.splitlines() if line.strip()} + + _remote_branch_tag_cache[normalized_branch] = (now, merged_tags) + return set(merged_tags) + + +def _get_local_branch_merged_tags( branch: str, repo_dir: str | Path | None = None, ) -> set[str]: @@ -207,6 +246,19 @@ def _get_branch_merged_tags( return set() +def _get_branch_merged_tags( + branch: str, + repo_dir: str | Path | None = None, +) -> set[str]: + try: + remote_tags = _get_remote_branch_merged_tags(branch) + if remote_tags: + return remote_tags + except Exception: + pass + return _get_local_branch_merged_tags(branch, repo_dir=repo_dir) + + def _parse_selector_version(tag: str) -> tuple[int, int] | None: match = re.fullmatch(r"v(\d+)\.(\d+)", tag.strip()) if not match: diff --git a/tests/test_self_update_tag_filter.py b/tests/test_self_update_tag_filter.py index 3387e3032..1897bfba3 100644 --- a/tests/test_self_update_tag_filter.py +++ b/tests/test_self_update_tag_filter.py @@ -36,6 +36,36 @@ def test_self_update_selector_tags_are_sorted_numerically(): ] +def test_self_update_branch_filter_prefers_remote_branch_tags(monkeypatch): + monkeypatch.setattr( + self_update.git, + "get_remote_releases", + lambda author, repo: types.SimpleNamespace( + error="", + releases=[ + types.SimpleNamespace(tag="v1.2"), + types.SimpleNamespace(tag="v1.1"), + types.SimpleNamespace(tag="v1.0"), + ], + ), + ) + monkeypatch.setattr( + self_update, + "_get_remote_branch_merged_tags", + lambda branch: {"v1.1", "v1.0"}, + ) + monkeypatch.setattr( + self_update, + "_get_local_branch_merged_tags", + lambda branch, repo_dir=None: set(), + ) + + tags, error = self_update.get_available_tags("development") + + assert error == "" + assert tags == ["v1.1", "v1.0"] + + def test_self_update_frontend_filters_old_tag_suggestions(): store_path = ( PROJECT_ROOT