commit 71e96d09f2ea76e14d1759006ccfe19211a3ddce
parent 7a0b376f9f09c65da7ba683ed4cc1034dffc345e
Author: gracefu <81774659+gracefuu@users.noreply.github.com>
Date: Sun, 20 Apr 2025 06:42:32 +0800
Replace tar2 primitive with rerun_if_changed
Diffstat:
3 files changed, 53 insertions(+), 33 deletions(-)
diff --git a/tar-sketch/README.md b/tar-sketch/README.md
@@ -4,4 +4,4 @@ The `tar` function opens the `manifest` file and puts files listed in it into an
Thus, the dependencies of `tar` is dynamic, we want to rerun `tar` if either the manifest changes (static) or if one of its input files changed (dynamic).
tar.py is a version where the dependency tracking logic is written by hand.
-tar2.py is a version using a minimal build system with a single primitive (rerun_if).
+tar2.py is a version using a minimal build system with a single primitive (rerun_if_changed).
diff --git a/tar-sketch/a.txt b/tar-sketch/a.txt
@@ -1 +1 @@
-Sun Apr 20 06:11:20 AM +08 2025
+Sun Apr 20 07:04:58 AM +08 2025
diff --git a/tar-sketch/tar2.py b/tar-sketch/tar2.py
@@ -2,6 +2,8 @@ import contextvars
import functools
import inspect
import pickle
+import types
+from typing import Any
db: dict = dict()
@@ -19,16 +21,43 @@ def get_or_set(d, k, v):
return v
-rerun_ifs_var = contextvars.ContextVar("rerun_ifs")
+rerun_changes_var: contextvars.ContextVar[list[tuple[str, Any]]] = (
+ contextvars.ContextVar("rerun_changes")
+)
+rerun_locals_var: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar(
+ "rerun_locals"
+)
-def with_rerun_context(rerun_ifs, f, /, *args, **kwargs):
- rerun_ifs_var.set(rerun_ifs)
+def with_rerun_context(rerun_changes, rerun_locals, f, /, *args, **kwargs):
+ rerun_changes_var.set(rerun_changes)
+ rerun_locals_var.set(rerun_locals)
return f(*args, **kwargs)
+class UseEval:
+ pass
+
+
+class UseCaller:
+ pass
+
+
+def rerun_if_changed(f_str, current_value: Any = UseEval):
+ rerun_changes_var.get().append(
+ (
+ f_str,
+ (
+ eval(f_str, locals=rerun_locals_var.get())
+ if current_value == UseEval
+ else current_value
+ ),
+ )
+ )
+
+
def rerun_if(f_str):
- rerun_ifs_var.get().append(f_str)
+ return rerun_if_changed(f"bool({f_str})", False)
def cache_conditionally(
@@ -36,32 +65,21 @@ def cache_conditionally(
if_cached_fn=lambda cached_result, /, *_, **__: cached_result,
):
def decorator(fn):
- sig = inspect.signature(fn)
- defaults = {
- p.name: p.default
- for p in sig.parameters.values()
- if p.default != inspect.Parameter.empty
- }
- sig = sig.replace(
- parameters=tuple(
- inspect.Parameter(name=p.name, kind=p.kind)
- for p in sig.parameters.values()
- )
- )
- sig_str = sig.format().lstrip("(").rstrip(")")
+ signature = inspect.signature(fn)
@functools.wraps(fn)
def wrapped(*args, **kwargs):
keys = keys_fn(*args, **kwargs)
+ bound_args = signature.bind(*args, **kwargs)
+ bound_args.apply_defaults()
+ rerun_locals = bound_args.arguments
if ("track", "result", fn.__qualname__, keys) in db:
- old_rerun_cond_list = get_or_set(
- db, ("track", "rerun_ifs", fn.__qualname__, keys), []
+ old_rerun_changes = get_or_set(
+ db, ("track", "rerun_changes", fn.__qualname__, keys), []
)
- for f in old_rerun_cond_list:
- new_kwargs = defaults.copy()
- new_kwargs.update(kwargs)
- res = eval(f"lambda {sig_str}: {f}")(*args, **new_kwargs)
- if res:
+ for expr, old_val in old_rerun_changes:
+ res = eval(expr, locals=rerun_locals)
+ if res != old_val:
break
else:
return if_cached_fn(
@@ -69,9 +87,11 @@ def cache_conditionally(
)
context = contextvars.copy_context()
- rerun_ifs = []
- result = context.run(with_rerun_context, rerun_ifs, fn, *args, **kwargs)
- db[("track", "rerun_ifs", fn.__qualname__, keys)] = rerun_ifs
+ rerun_changes = []
+ result = context.run(
+ with_rerun_context, rerun_changes, rerun_locals, fn, *args, **kwargs
+ )
+ db[("track", "rerun_changes", fn.__qualname__, keys)] = rerun_changes
db[("track", "result", fn.__qualname__, keys)] = result
return result
@@ -92,7 +112,7 @@ file_modtime = lambda f: os.stat(f.fileno()).st_mtime_ns
@cache_conditionally(lambda f: f.name)
def _file_hash(f: BufferedReader):
- rerun_if(f"file_modtime(f) != {repr(file_modtime(f))}")
+ rerun_if_changed("file_modtime(f)")
h = hashlib.sha256()
for chunk in f:
h.update(chunk)
@@ -112,13 +132,13 @@ def file_hash(f: BufferedReader | bytes | str):
def tar(manifest=b"manifest", output=b"archive.tar.gz"):
with open(manifest, "rb") as manifest_f:
manifest_lines = manifest_f.read().splitlines()
- rerun_if(f"file_hash({repr(manifest)}) != {repr(file_hash(manifest_f))}")
+ rerun_if_changed(f"file_hash(manifest)")
for fname in manifest_lines:
- rerun_if(f"file_hash({repr(fname)}) != {repr(file_hash(fname))}")
+ rerun_if_changed(f"file_hash({repr(fname)})")
print("tar", "cvzf", output, *manifest_lines)
subprocess.run([b"tar", b"cvzf", output, *manifest_lines])
- rerun_if(f"file_hash({repr(output)}) != {repr(file_hash(output))}")
+ rerun_if_changed(f"file_hash({repr(output)})")
tar()