# Source code for conda_tasks.cache

"""Task output caching using file fingerprints.

Cache entries are stored in a platform-appropriate directory via
``platformdirs``.  Each project gets a subdirectory keyed by a hash
of the project root path. Within that, each task has a JSON file
containing fingerprints of its inputs and outputs.

A fast pre-check using ``(mtime, size)`` tuples avoids SHA-256
hashing when files haven't been touched since the last run.
"""

from __future__ import annotations

import glob
import hashlib
import json
import os
from pathlib import Path
from typing import TYPE_CHECKING

from platformdirs import user_cache_dir

if TYPE_CHECKING:
    from typing import Any


def _cache_root() -> Path:
    """Locate the top-level cache directory for conda-tasks.

    Directory selection is delegated to ``platformdirs`` so each OS
    gets its conventional per-user cache location.
    """
    root = user_cache_dir("conda-tasks")
    return Path(root)


def _project_cache_dir(project_root: Path) -> Path:
    """Return (creating it if needed) the cache directory for one project.

    The directory name is the first 16 hex chars of a SHA-256 over the
    resolved project root path, so distinct projects never collide.
    """
    resolved = str(project_root.resolve()).encode()
    digest = hashlib.sha256(resolved).hexdigest()
    project_dir = _cache_root() / digest[:16]
    project_dir.mkdir(parents=True, exist_ok=True)
    return project_dir


def _cache_file(project_root: Path, task_name: str) -> Path:
    """Locate the per-task JSON cache file inside the project cache dir."""
    filename = task_name + ".json"
    return _project_cache_dir(project_root) / filename


def _file_stat(path: str) -> tuple[float, int] | None:
    """Return ``(mtime, size)`` for *path*, or None if missing."""
    try:
        st = os.stat(path)
        return (st.st_mtime, st.st_size)
    except OSError:
        return None


def _file_sha256(path: str) -> str:
    """Return the hex SHA-256 digest of the file at *path*."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(8192), b""):
            h.update(chunk)
    return h.hexdigest()


def _expand_globs(patterns: list[str], cwd: Path) -> list[str]:
    """Expand glob patterns relative to *cwd*, return sorted paths."""
    result: set[str] = set()
    for pattern in patterns:
        expanded = glob.glob(str(cwd / pattern), recursive=True)
        result.update(expanded)
    return sorted(result)


def _fingerprint_files(paths: list[str]) -> dict[str, dict[str, Any]]:
    """Fingerprint each existing path as ``{mtime, size, sha256}``.

    Paths that cannot be stat'ed (e.g. deleted between glob expansion
    and this call) are silently omitted rather than treated as errors.
    """
    fingerprints: dict[str, dict[str, Any]] = {}
    for path in paths:
        quick = _file_stat(path)
        if quick is None:
            # File vanished or is unreadable; leave it out of the map.
            continue
        mtime, size = quick
        fingerprints[path] = {
            "mtime": mtime,
            "size": size,
            "sha256": _file_sha256(path),
        }
    return fingerprints


def _compute_entry(
    cmd: str,
    env: dict[str, str],
    input_files: list[str],
    output_files: list[str],
) -> dict[str, Any]:
    """Snapshot the current task state as a cache-entry dict.

    The command string and the sorted-key JSON form of *env* are each
    hashed so the entry invalidates when either changes; the input and
    output file lists are fingerprinted individually.
    """
    env_blob = json.dumps(env, sort_keys=True)
    return {
        "cmd_hash": hashlib.sha256(cmd.encode()).hexdigest(),
        "env_hash": hashlib.sha256(env_blob.encode()).hexdigest(),
        "inputs": _fingerprint_files(input_files),
        "outputs": _fingerprint_files(output_files),
    }


def is_cached(
    project_root: Path,
    task_name: str,
    cmd: str,
    env: dict[str, str],
    input_patterns: list[str],
    output_patterns: list[str],
    cwd: Path,
) -> bool:
    """Check whether the task can be skipped (cache hit).

    Returns True only when all of the following hold:

    1. A cache entry exists for the task and parses as JSON.
    2. The command and env hashes match.
    3. All input files match by ``(mtime, size)`` -- falling back to
       SHA-256 only when the fast check fails.
    4. All output files still exist and match.

    Fix vs. previous version: fingerprints are now compared lazily.
    The old code called ``_compute_entry`` up front, which SHA-256
    hashed every input and output file before any comparison ran --
    defeating the module's documented ``(mtime, size)`` fast path.
    Now a file is only hashed when its ``(mtime, size)`` pair changed,
    and nothing is hashed at all when the cmd/env hashes already
    disagree.
    """
    cf = _cache_file(project_root, task_name)
    if not cf.exists():
        return False
    try:
        cached = json.loads(cf.read_text(encoding="utf-8"))
    except (json.JSONDecodeError, OSError):
        # Unreadable or corrupt cache file => treat as a miss.
        return False

    # Cheap hash comparisons first: bail out before touching any files.
    cmd_hash = hashlib.sha256(cmd.encode()).hexdigest()
    if cached.get("cmd_hash") != cmd_hash:
        return False
    env_hash = hashlib.sha256(json.dumps(env, sort_keys=True).encode()).hexdigest()
    if cached.get("env_hash") != env_hash:
        return False

    input_files = _expand_globs(input_patterns, cwd)
    output_files = _expand_globs(output_patterns, cwd)
    if not _lazy_paths_match(cached.get("inputs", {}), input_files):
        return False
    if not _lazy_paths_match(cached.get("outputs", {}), output_files):
        return False
    return True


def _lazy_paths_match(cached: dict[str, Any], paths: list[str]) -> bool:
    """Return True if *paths* still matches the *cached* fingerprint dict.

    A file is SHA-256 hashed only when its ``(mtime, size)`` pair
    differs from the cached values (or the cached record is missing
    those keys), so an untouched tree validates without re-reading any
    file content.  Missing-key records compare as a mismatch instead of
    raising ``KeyError``.
    """
    if set(cached) != set(paths):
        # A file appeared or disappeared since the entry was written.
        return False
    for path in paths:
        prev = cached[path]
        stat = _file_stat(path)
        if stat is None:
            # File vanished between glob expansion and stat.
            return False
        # Fast pre-check: identical (mtime, size) => assume unchanged.
        if prev.get("mtime") == stat[0] and prev.get("size") == stat[1]:
            continue
        # Timestamps moved; fall back to content comparison.
        if prev.get("sha256") != _file_sha256(path):
            return False
    return True
def _files_match(cached: dict[str, Any], current: dict[str, Any]) -> bool: """Compare two fingerprint dicts. Fast path: if ``(mtime, size)`` match, skip SHA-256 comparison. """ if set(cached.keys()) != set(current.keys()): return False for path, cur in current.items(): prev = cached.get(path) if prev is None: return False if prev["mtime"] == cur["mtime"] and prev["size"] == cur["size"]: continue if prev["sha256"] != cur["sha256"]: return False return True
def save_cache(
    project_root: Path,
    task_name: str,
    cmd: str,
    env: dict[str, str],
    input_patterns: list[str],
    output_patterns: list[str],
    cwd: Path,
) -> None:
    """Write or update the cache entry for a task.

    Expands the input/output glob patterns relative to *cwd*,
    fingerprints the matched files, and serializes the resulting entry
    to the task's JSON cache file.
    """
    entry = _compute_entry(
        cmd,
        env,
        _expand_globs(input_patterns, cwd),
        _expand_globs(output_patterns, cwd),
    )
    target = _cache_file(project_root, task_name)
    target.write_text(json.dumps(entry, indent=2), encoding="utf-8")