commit f2d98d626b90fa7f9fc2ef4f4b788b7949bb7baa
parent 1ae6809bce7c79ec5bf63b22790ec3d9f0d943cf
Author: gracefu <81774659+gracefuu@users.noreply.github.com>
Date: Wed, 23 Apr 2025 05:23:55 +0800
Incorporate @once and @cache_conditionally by converting the entire rebuild stack to async/await, which involves some AST rewriting.
Diffstat:
8 files changed, 115 insertions(+), 61 deletions(-)
diff --git a/make3/__init__.py b/make3/__init__.py
@@ -1,3 +1,6 @@
+from .exec import *
+from .helpers import *
from .io import *
+from .main import *
+from .once import *
from .rebuild import *
-from .helpers import *
diff --git a/make3/exec.py b/make3/exec.py
@@ -1,7 +1,7 @@
+from asyncio import create_subprocess_exec, create_subprocess_shell, gather
from asyncio.streams import StreamReader, StreamWriter
from asyncio.subprocess import Process, PIPE
from collections import namedtuple
-from asyncio import create_subprocess_exec, create_subprocess_shell, gather
import sys
diff --git a/make3/helpers.py b/make3/helpers.py
@@ -1,29 +1,35 @@
from .io import open
-from .rebuild import cache_conditionally, rerun_if_changed
+from .once import once
+from .rebuild import cache_conditionally, rerun_if_changed, rerun_always
import hashlib
import os
-def file_modtime(f: int | str | bytes | os.PathLike[str] | os.PathLike[bytes]):
+async def file_modtime(f: int | str | bytes | os.PathLike[str] | os.PathLike[bytes]):
return os.stat(f).st_mtime_ns
+@once()
@cache_conditionally(lambda f, *args: (f.name, *args))
-def _file_hash(f: open, skip_if_modtime_matches=True):
+async def _file_hash(f: open, skip_if_modtime_matches=True):
if skip_if_modtime_matches:
- rerun_if_changed(lambda: file_modtime(f.fileno()))
+ rerun_if_changed()(await file_modtime(f.fileno()))
else:
- rerun_if_changed(False, lambda: True) # always rerun
+ rerun_always()
h = hashlib.sha256()
- for chunk in f:
- h.update(chunk)
+ while True:
+ chunk = f.read1()
+ if chunk:
+ h.update(chunk)
+ else:
+ break
d = h.hexdigest()
# print("hash", f.name, d)
return d
-def file_hash(f: open | bytes | str, skip_if_modtime_matches=True):
+async def file_hash(f: open | bytes | str, skip_if_modtime_matches=True):
if isinstance(f, bytes) or isinstance(f, str):
with open(f, "rb") as _f:
- return _file_hash(_f, skip_if_modtime_matches)
- return _file_hash(f, skip_if_modtime_matches)
+ return await _file_hash(_f, skip_if_modtime_matches)
+ return await _file_hash(f, skip_if_modtime_matches)
diff --git a/make3/main.py b/make3/main.py
@@ -1,9 +1,11 @@
-import sys
+from .rebuild import Rerunner
from asyncio import gather
+import sys
async def make_main(globals, default_target="all()"):
targets = sys.argv[1:]
if not targets:
targets.append(default_target)
- await gather(*(eval(target, globals=globals) for target in targets))
+ with Rerunner():
+ await gather(*(eval(target, globals=globals) for target in targets))
diff --git a/make3/once.py b/make3/once.py
@@ -1,7 +1,6 @@
-from concurrent.futures import Executor
-from typing import Any, Awaitable, Callable, ParamSpec, TypeVar
-from asyncio import Future, get_event_loop
-from functools import partial, wraps
+from typing import Any
+from asyncio import Future
+from functools import wraps
from inspect import signature
@@ -29,22 +28,3 @@ def once():
return wrapped
return decorator
-
-
-def in_executor(executor: Executor | None = None):
- Args = ParamSpec("Args")
- T = TypeVar("T")
-
- def decorator(f: Callable[Args, T]) -> Callable[Args, Awaitable[T]]:
- @wraps(f)
- def wrapped(*args, **kwargs):
- if kwargs:
- return get_event_loop().run_in_executor(
- executor, partial(f, **kwargs), *args
- )
- else:
- return get_event_loop().run_in_executor(executor, f, *args)
-
- return wrapped
-
- return decorator
diff --git a/make3/rebuild.py b/make3/rebuild.py
@@ -1,13 +1,15 @@
+from asyncio import create_task
from contextvars import ContextVar, copy_context
from copyreg import dispatch_table
-from functools import wraps
+from functools import update_wrapper, wraps
from importlib import import_module
from inspect import getmodule
from io import BytesIO
from pickle import Pickler
from types import CellType, CodeType, FunctionType
from typing import Any
-from typing import Any, Callable, overload
+import ast
+import inspect
import pickle
@@ -96,34 +98,92 @@ rerun_db_var: ContextVar[dict] = ContextVar("rerun_db")
rerun_changes_var: ContextVar[list[tuple[Any, bytes]]] = ContextVar("rerun_changes")
-def with_rerun_context(rerun_changes, f, /, *args, **kwargs):
+async def with_rerun_context(rerun_changes, f, /, *args, **kwargs):
rerun_changes_var.set(rerun_changes)
- return f(*args, **kwargs)
+ return await f(*args, **kwargs)
-@overload
-def rerun_if_changed(now: Callable, *, pickler_cls: type = FunctionPickler): ...
-@overload
-def rerun_if_changed(
- now: Any, later: Callable, *, pickler_cls: type = FunctionPickler
-): ...
-def rerun_if_changed(now, later=None, *, pickler_cls=FunctionPickler):
- later_pkl = pickle_with(pickler_cls, now if later is None else later)
- rerun_changes_var.get().append((now() if later is None else now, later_pkl))
+def rewrite_rerun_if_changed(frame=None):
+ def decorator(fn):
+ s = inspect.getsource(fn).splitlines()
+ i = 0
+ while not s[i].startswith("async def"):
+ i += 1
+ s = s[i:]
+ s = "\n".join(s)
+ a = ast.parse(s).body[0]
+
+ class RewriteCalls(ast.NodeTransformer):
+ def visit_Expr(self, node: ast.Expr):
+ if (
+ isinstance(node.value, ast.Call)
+ and isinstance(node.value.func, ast.Call)
+ and isinstance(node.value.func.func, ast.Name)
+ and node.value.func.func.id == "rerun_if_changed"
+ ):
+ if len(node.value.func.args) == 0:
+ node.value.func.args.append(node.value.args[0])
+ out = ast.AsyncFunctionDef(
+ "_",
+ ast.arguments(),
+ [ast.Return(node.value.args[0])],
+ [node.value.func],
+ )
+ return out
+ return node
+
+ a = ast.fix_missing_locations(RewriteCalls().visit(a))
+ # print(ast.unparse(a))
+ frame_ = frame if frame else inspect.currentframe().f_back # type: ignore
+ exec(ast.unparse(a), frame_.f_globals, frame_.f_locals, closure=fn.__closure__) # type: ignore
+ fn_ = list(frame_.f_locals.values())[-1] # type: ignore
+
+ fn_ = update_wrapper(fn_, fn)
+ return fn_
+
+ return decorator
+
+
+class _RunLaterNow: ...
+
+
+def rerun_if_changed(now: Any = _RunLaterNow, *, pickler_cls=FunctionPickler):
+ def decorator(later):
+ later_pkl = pickle_with(pickler_cls, later)
+ if now is _RunLaterNow:
+ raise RuntimeError(
+ "Should have been preprocessed away by the cache_conditionally macro"
+ )
+ else:
+ rerun_changes_var.get().append((now, later_pkl))
+
+ return decorator
def rerun_if(f):
- return rerun_if_changed(False, lambda: bool(f()))
+ @rerun_if_changed(False)
+ async def _():
+ return bool(await f())
+
+
+def rerun_always():
+ @rerun_if_changed(False)
+ async def _():
+ return True
def cache_conditionally(
keys_fn=lambda *args, **kwargs: (args, tuple(sorted(kwargs.items()))),
store_fn=lambda result, /, *_, **__: result,
load_fn=lambda cached_result, /, *_, **__: cached_result,
+ rewrite=True,
):
def decorator(fn):
+ if rewrite:
+ fn = rewrite_rerun_if_changed(inspect.currentframe().f_back)(fn) # type: ignore
+
@wraps(fn)
- def wrapped(*args, **kwargs):
+ async def wrapped(*args, **kwargs):
db = rerun_db_var.get()
keys = keys_fn(*args, **kwargs)
db_key = ("track", fn.__qualname__, keys)
@@ -136,7 +196,7 @@ def cache_conditionally(
for old_val, f_pkl in old_rerun_changes:
try:
f_unpkled = pickle.loads(f_pkl)
- val = f_unpkled()
+ val = await f_unpkled()
if old_val != val:
break
except BaseException:
@@ -144,9 +204,11 @@ def cache_conditionally(
else:
return load_fn(db[db_key + ("result",)], *args, **kwargs)
- context = copy_context()
rerun_changes = []
- result = context.run(with_rerun_context, rerun_changes, fn, *args, **kwargs)
+ result = await create_task(
+ with_rerun_context(rerun_changes, fn, *args, **kwargs),
+ context=copy_context(),
+ )
db[db_key + ("rerun_changes",)] = rerun_changes
db[db_key + ("result",)] = store_fn(result, *args, **kwargs)
return result
diff --git a/tar-sketch/a.txt b/tar-sketch/a.txt
@@ -1 +1 @@
-Wed Apr 23 02:56:16 AM +08 2025
+Wed Apr 23 04:53:21 AM +08 2025
diff --git a/tar-sketch/tar2.py b/tar-sketch/tar2.py
@@ -1,3 +1,4 @@
+import asyncio
import sys
sys.path.append("..")
@@ -6,17 +7,17 @@ import subprocess
@cache_conditionally()
-def tar(manifest=b"manifest", output=b"archive.tar.gz"):
+async def tar(manifest=b"manifest", output=b"archive.tar.gz"):
with open(manifest, "rb") as manifest_f:
manifest_lines = manifest_f.read().splitlines()
- rerun_if_changed(lambda: file_hash(manifest_f))
+ rerun_if_changed()(await file_hash(manifest_f))
for fname in manifest_lines:
- rerun_if_changed(lambda: file_hash(fname))
+ rerun_if_changed()(await file_hash(fname))
print("tar", "cvzf", output, *manifest_lines)
subprocess.run([b"tar", b"cvzf", output, *manifest_lines])
- rerun_if_changed(lambda: file_hash(output))
+ rerun_if_changed()(await file_hash(output))
with Rerunner():
- tar()
+ asyncio.run(tar())