commit 6dfb840cac304b0a6a64afe4c5f8bf89ea0e33cb
parent 078b4e2875257e2c5e63dd93fcc4fbd96e46542e
Author: gracefu <81774659+gracefuu@users.noreply.github.com>
Date: Wed, 23 Apr 2025 02:22:03 +0800
Add once / exec utils, use from ... import ... instead of import ... to make the code harder to read
Diffstat:
5 files changed, 168 insertions(+), 26 deletions(-)
diff --git a/make3/exec.py b/make3/exec.py
@@ -0,0 +1,93 @@
+from asyncio.streams import StreamReader, StreamWriter
+from asyncio.subprocess import Process, PIPE
+from collections import namedtuple
+from asyncio import create_subprocess_exec, create_subprocess_shell, gather
+import sys
+
+
class ShellResult(namedtuple("ShellResult", "stdout stderr returncode")):
    """Outcome of a finished subprocess: raw output bytes plus exit status."""

    __slots__ = ()

    @property
    def utf8stdout(self):
        """Captured stdout decoded as UTF-8 text."""
        return str(self.stdout, "utf-8")

    @property
    def utf8stderr(self):
        """Captured stderr decoded as UTF-8 text."""
        return str(self.stderr, "utf-8")
+
+
# Bit flags selecting which child streams are mirrored to our own
# stdout/stderr while their output is being captured.
EchoNothing, EchoStdout, EchoStderr, EchoAll = 0, 1, 2, 1 | 2
+
+
+async def _exec_writer(
+ proc: Process,
+ ostream: StreamWriter,
+ input: bytes | bytearray | memoryview | None = None,
+):
+ if input is not None:
+ ostream.write(input)
+ await ostream.drain()
+ return await proc.wait()
+
+
+async def _exec_reader(istream: StreamReader, ostream=None):
+ contents = b""
+ while not istream.at_eof():
+ chunk = await istream.read(4096 * 16)
+ contents += chunk
+ if ostream:
+ ostream.write(chunk)
+ ostream.flush()
+ return contents
+
+
async def communicate_echo_wait(
    proc: Process,
    input: bytes | bytearray | memoryview | None = None,
    echo: int = EchoNothing,
) -> ShellResult:
    """Pump *input* into *proc*, capture both output streams, and wait.

    *echo* is a bitmask (``EchoStdout`` / ``EchoStderr`` / ``EchoAll``)
    choosing which captured streams are simultaneously mirrored to this
    process's own stdout/stderr.  Returns a ``ShellResult``.
    """
    echo_out = sys.stdout.buffer if echo & EchoStdout else None
    echo_err = sys.stderr.buffer if echo & EchoStderr else None
    # Reader/writer tasks run concurrently so neither pipe can fill up
    # and stall the child.
    results = await gather(
        _exec_reader(proc.stdout, echo_out),  # type: ignore
        _exec_reader(proc.stderr, echo_err),  # type: ignore
        _exec_writer(proc, proc.stdin, input),  # type: ignore
    )
    return ShellResult(*results)
+
+
async def exec(
    prog,
    *args,
    input: bytes | bytearray | memoryview | None = None,
    echo: int = EchoNothing,
) -> ShellResult:
    """Run *prog* with *args* directly (no shell) and capture its output.

    NOTE: this name shadows the builtin ``exec`` within this module.
    """
    proc = await create_subprocess_exec(
        prog, *args, stdin=PIPE, stdout=PIPE, stderr=PIPE
    )
    return await communicate_echo_wait(proc, input, echo)
+
+
async def shell(
    cmd,
    input: bytes | bytearray | memoryview | None = None,
    echo: int = EchoNothing,
) -> ShellResult:
    """Run *cmd* through the system shell and capture its output.

    Shell interpolation applies — only pass trusted command strings.
    """
    proc = await create_subprocess_shell(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE)
    return await communicate_echo_wait(proc, input, echo)
diff --git a/make3/main.py b/make3/main.py
@@ -0,0 +1,9 @@
+import sys
+from asyncio import gather
+
+
async def make_main(globals, default_target="all()"):
    """Evaluate each command-line argument as a build-target expression.

    Targets come from ``sys.argv[1:]``; with no arguments,
    *default_target* runs instead.  All targets run concurrently.

    ``globals`` (name kept for backward compatibility; it shadows the
    builtin) is the namespace the expressions are evaluated in.

    SECURITY NOTE: arguments are ``eval``-ed verbatim — by design for a
    make-style tool, but only ever run with trusted argv.
    """
    targets = sys.argv[1:] or [default_target]
    # eval() only accepts globals as a keyword from Python 3.13 on;
    # passing it positionally works everywhere.
    await gather(*(eval(target, globals) for target in targets))
diff --git a/make3/once.py b/make3/once.py
@@ -0,0 +1,50 @@
+from concurrent.futures import Executor
+from typing import Any, Awaitable, Callable, ParamSpec, TypeVar
+from asyncio import Future, get_event_loop
+from functools import partial, wraps
+from inspect import signature
+
+
def once():
    """Decorator factory memoizing an async function per argument tuple.

    Each distinct set of call arguments runs the wrapped coroutine at
    most once; concurrent callers with the same arguments await the same
    in-flight Future.  Failures are cached too: later identical calls
    re-raise the stored exception.
    """

    def decorator(f):
        # Local import: the module header only brings in `signature`.
        from inspect import Parameter

        futs: dict[tuple[Any, ...], Future] = {}
        sig = signature(f)

        def _cache_key(args, kwargs) -> tuple[Any, ...]:
            """Normalize a call into a hashable key.

            A VAR_KEYWORD (**kwargs) parameter binds as a dict, which is
            unhashable and would make every call crash with TypeError;
            freeze it into a frozenset of items instead.  Other
            parameter kinds keep their bound value unchanged, so keys
            for plain signatures are identical to the naive scheme.
            """
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            parts = []
            for name, value in bound.arguments.items():
                if sig.parameters[name].kind is Parameter.VAR_KEYWORD:
                    parts.append(frozenset(value.items()))
                else:
                    parts.append(value)
            return tuple(parts)

        @wraps(f)
        async def wrapped(*args, **kwargs):
            key = _cache_key(args, kwargs)
            if key in futs:
                # Already started (or finished): share the same outcome.
                return await futs[key]
            # No await between the lookup above and this insert, so two
            # coroutines cannot both claim the same key.
            fut = futs[key] = Future()
            try:
                res = await f(*args, **kwargs)
                fut.set_result(res)
                return res
            except BaseException as e:
                fut.set_exception(e)
                raise

        return wrapped

    return decorator
+
+
+def in_executor(executor: Executor | None = None):
+ Args = ParamSpec("Args")
+ T = TypeVar("T")
+
+ def decorator(f: Callable[Args, T]) -> Callable[Args, Awaitable[T]]:
+ @wraps(f)
+ def wrapped(*args, **kwargs):
+ if kwargs:
+ return get_event_loop().run_in_executor(
+ executor, partial(f, **kwargs), *args
+ )
+ else:
+ return get_event_loop().run_in_executor(executor, f, *args)
+
+ return wrapped
+
+ return decorator
diff --git a/make3/pickler.py b/make3/pickler.py
@@ -1,10 +1,10 @@
from io import BytesIO
from types import CellType, CodeType, FunctionType
from typing import Any
-import copyreg
-import importlib
-import inspect
-import pickle
+from copyreg import dispatch_table
+from importlib import import_module
+from inspect import getmodule
+from pickle import Pickler
def pickle_code_type(code: CodeType):
@@ -46,7 +46,7 @@ def unpickle_cell_type(*args):
def pickle_function_type(f: FunctionType):
- mod = inspect.getmodule(f)
+ mod = getmodule(f)
return (
unpickle_function_type,
(
@@ -62,20 +62,18 @@ def pickle_function_type(f: FunctionType):
def unpickle_function_type(code, mod_name, closure):
- return FunctionType(
- code, globals=importlib.import_module(mod_name).__dict__, closure=closure
- )
+ return FunctionType(code, globals=import_module(mod_name).__dict__, closure=closure)
-class FunctionPickler(pickle.Pickler):
- dispatch_table = copyreg.dispatch_table.copy()
+class FunctionPickler(Pickler):
+ dispatch_table = dispatch_table.copy()
dispatch_table[CodeType] = pickle_code_type
dispatch_table[CellType] = pickle_cell_type
def reducer_override(self, obj): # type: ignore
if type(obj) is not FunctionType:
return NotImplemented
- obj_mod = inspect.getmodule(obj)
+ obj_mod = getmodule(obj)
if obj_mod is None:
return NotImplemented
if obj.__name__ in dir(obj_mod):
diff --git a/make3/rebuild.py b/make3/rebuild.py
@@ -1,13 +1,11 @@
from .pickler import FunctionPickler, pickle_with
from typing import Any, Callable, overload
-import contextvars
-import functools
+from contextvars import ContextVar, copy_context
+from functools import wraps
import pickle
-rerun_db_var: contextvars.ContextVar[dict] = contextvars.ContextVar("rerun_db")
-rerun_changes_var: contextvars.ContextVar[list[tuple[Any, bytes]]] = (
- contextvars.ContextVar("rerun_changes")
-)
+rerun_db_var: ContextVar[dict] = ContextVar("rerun_db")
+rerun_changes_var: ContextVar[list[tuple[Any, bytes]]] = ContextVar("rerun_changes")
def with_rerun_context(rerun_changes, f, /, *args, **kwargs):
@@ -36,7 +34,7 @@ def cache_conditionally(
load_fn=lambda cached_result, /, *_, **__: cached_result,
):
def decorator(fn):
- @functools.wraps(fn)
+ @wraps(fn)
def wrapped(*args, **kwargs):
db = rerun_db_var.get()
keys = keys_fn(*args, **kwargs)
@@ -58,15 +56,9 @@ def cache_conditionally(
else:
return load_fn(db[db_key + ("result",)], *args, **kwargs)
- context = contextvars.copy_context()
+ context = copy_context()
rerun_changes = []
- result = context.run(
- with_rerun_context,
- rerun_changes,
- fn,
- *args,
- **kwargs,
- )
+ result = context.run(with_rerun_context, rerun_changes, fn, *args, **kwargs)
db[db_key + ("rerun_changes",)] = rerun_changes
db[db_key + ("result",)] = store_fn(result, *args, **kwargs)
return result