commit 09ea2700c79218ee15fb8263f0fdf65657619aa4
parent 8c4e9f385ac40871964ea51fe76fa97f41c13180
Author: gracefu <81774659+gracefuu@users.noreply.github.com>
Date: Tue, 15 Apr 2025 03:33:06 +0800
Refactor, tweak, support concurrency
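
The build system now exposes `rule`, `rawrule`, and `cache` at module level,
persists results in a pickled `Store`, and runs tasks concurrently via
`SuspendingFetch` and `detach`. A minimal usage sketch of the new API (the
rule name `greet` is illustrative only; see run_examples in the diff for the
real version):

    @cache
    @rule
    async def greet(fetch):
        return "hello"

    async def main():
        with Store("make.db", _rules) as store:
            fetch = SuspendingFetch(store)
            print(await fetch(greet()))

    asyncio.run(main())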
Diffstat:
| M | make.py | | | 324 | ++++++++++++++++++++++++++++++++++++++++++++++++------------------------------- |
1 file changed, 199 insertions(+), 125 deletions(-)
diff --git a/make.py b/make.py
@@ -2,7 +2,7 @@
pymake
------
-Design inspired by the paper `Build Systems à la Carte'
+Design inspired by the paper "Build Systems à la Carte"
- https://github.com/snowleopard/build
- https://www.microsoft.com/en-us/research/wp-content/uploads/2018/03/build-systems.pdf
@@ -17,13 +17,11 @@ As such, we will adopt mostly the same vocabulary:
In our system, we make some slight adjustments
-- In fact, we don't distinguish tasks and keys -- we pass around the tasks themselves.
- - For storage purposes, we treat task equality based on string equality of the task function.
-- We focus on implementing the suspending scheduler and constructive traces rebuilder.
+- For convenience, we automatically derive the task key from the task function; see _fn_to_key.
- As with the paper, we don't handle dependency cycles, since it's unclear which key to "seed" and with what "seed value".
-- While `fetch` in the paper is a parameter that's passed around, we just have it be a global function in our case.
-- Similarly, while `rebuilder` is a global that's fixed for the whole build system, we interpret it really as just a fancy way to say @cache, and so it really makes more sense to let each task choose its rebuild strategy.
-- This means that there is no such thing as an "input" or "output"/"intermediate" key, an input key is simply a key that hasn't been wrapped by a rebuilder.
+- While `rebuilder` in the paper is a global fixed for the whole build system, we reinterpret it as a cache policy, so it makes more sense to let each task choose its own cache policy rather than having a single global one.
+- This means there is no such thing as an "input" or "output"/"intermediate" key: an input key is simply a key that hasn't been wrapped by a caching layer.
+- We focus on implementing the suspending scheduler and constructive traces cache policy.
"""
import asyncio
@@ -35,14 +33,13 @@ import hashlib
from typing import Awaitable, Callable, Any, Concatenate, Optional
-# Rules are functions that take in python primitives (bool, int, none, str) and tasks, and output a task.
-# Tasks are coroutine functions with a single argument `build`
-# All rules must be registered with the decorator @rule.
-#
-# For convenience, rules with no arguments can also be created by decorating @task on the coroutine function directly.
+FetchFn = Callable[["Task"], Awaitable[Any]]
+TaskKey = str
+RuleKey = str
-def make_hash(o: Any) -> bytes:
+
+def _make_hash(o: Any) -> bytes:
h = hashlib.sha256()
if isinstance(o, bytes):
h.update(b"s")
@@ -53,7 +50,20 @@ def make_hash(o: Any) -> bytes:
return h.digest()
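+# Key a function by its name, source length, and a truncated hash of its
+# source, so editing a rule's body changes its key and thereby invalidates
+# previously cached results for that rule.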
+def _fn_to_key(fn) -> str:
+ name = fn.__name__
+ source = inspect.getsource(fn)
+ h = hashlib.sha256(source.encode("utf-8")).hexdigest()[:16]
+ key = f"{name}-{len(source)}-{h}"
+ return key
+
+
class Task:
+ task_key: TaskKey
+ rule_fn: Callable[Concatenate[FetchFn, TaskKey, "Store", ...], Awaitable[Any]]
+ args: tuple
+ hash: int
+
@staticmethod
def new(rule, *args):
return Task(
@@ -61,18 +71,18 @@ class Task:
rule.rule_key,
*(arg.task_key if hasattr(arg, "task_key") else arg for arg in args),
),
- rule,
+ rule.rule_fn,
*args,
)
- def __init__(self, task_key, rule, *args):
+ def __init__(self, task_key, rule_fn, *args):
self.task_key = task_key
- self.rule = rule
+ self.rule_fn = rule_fn
self.args = args
self.hash = hash(self.task_key)
- def __call__(self, fetch: "Fetch"):
- return self.rule.rule_fn(fetch, *self.args)
+ def __call__(self, fetch: "FetchFn", store: "Store"):
+ return self.rule_fn(fetch, self.task_key, store, *self.args)
def __repr__(self):
return repr(self.task_key)
@@ -84,30 +94,14 @@ class Task:
return self.hash
-class Fetch:
- fetch_fn: Callable[[Task], Awaitable[Any]]
- task: Task
- build: "Build"
-
- def __init__(self, fetch_fn, task, build):
- self.fetch_fn = fetch_fn
- self.task = task
- self.build = build
-
- def __call__(self, dep: Task):
- return self.fetch_fn(dep)
-
-
class Rule:
- rule_key: str
- rule_fn: Callable[Concatenate[Fetch, ...], Awaitable[Any]]
+ rule_key: RuleKey
+ rule_fn: Callable[Concatenate[FetchFn, TaskKey, "Store", ...], Awaitable[Any]]
+ hash: int
@staticmethod
def new(rule_fn):
- name = rule_fn.__name__
- source = inspect.getsource(rule_fn)
- h = hashlib.sha256(source.encode("utf-8")).hexdigest()[:16]
- return Rule(f"{name}-{len(source)}-{h}", rule_fn)
+ return Rule(_fn_to_key(rule_fn), rule_fn)
def __init__(self, rule_key, rule_fn):
self.rule_key = rule_key
@@ -125,38 +119,43 @@ class Rule:
class Rules:
+ rules: dict[RuleKey, Rule]
+
def __init__(self):
self.rules = dict()
- def rule(self):
- def decorator(rule_fn):
- rule = Rule.new(rule_fn)
- self.rules[rule.rule_key] = rule
- return rule
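+ # rule adapts a simple (fetch, *args) coroutine to the full rule signature and registers it.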
+ def rule(self, rule_fn):
+ @self.rawrule
+ @functools.wraps(rule_fn)
+ def wrapped(fetch, task_key, store, *args):
+ return rule_fn(fetch, *args)
- return decorator
+ return wrapped
+
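+ # rawrule registers a coroutine that takes the full (fetch, task_key, store, *args) signature.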
+ def rawrule(self, rule_fn):
+ rule = Rule.new(rule_fn)
+ self.rules[rule.rule_key] = rule
+ return rule
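+
+ # Rebuild a Task from a stored task key tuple; returns None if any rule
+ # referenced by the key no longer exists (e.g. its source changed, giving
+ # it a new key).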
def eval_task_key(self, task_key) -> Optional[Task]:
rule_key, *arg_keys = task_key
if rule_key not in self.rules:
return None
+ rule = self.rules[rule_key]
+ args = []
for arg in arg_keys:
if isinstance(arg, tuple) and arg[0] not in self.rules:
return None
- rule = self.rules[rule_key]
- args = (
- self.eval_task_key(arg) if isinstance(arg, tuple) else arg
- for arg in arg_keys
- )
+ args.append(self.eval_task_key(arg) if isinstance(arg, tuple) else arg)
return rule(*args)
# Wraps a rule so it only gets rebuilt if the constructive traces don't match
- def ctRebuilder(self):
+ def cache(self):
def decorator(rule: Rule):
@functools.wraps(rule.rule_fn)
- async def new_rule_fn(fetch, *args):
- past_runs = fetch.build.key_info[fetch.task.task_key]
- output_value = fetch.build.key_value[fetch.task.task_key]
+ async def new_rule_fn(fetch: FetchFn, task_key: str, store: "Store", *args):
+ past_runs = store.key_info[task_key]
+ output_value = store.key_value[task_key]
possible_values = []
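+ # Constructive traces: a past run's value may be reused only if every
+ # input it recorded still fetches to a value with the same hash. The
+ # inner for/else below fires only when no input hash mismatched.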
for past_inputs, past_value in past_runs:
for past_input_key, past_input_hash in past_inputs:
@@ -164,7 +163,7 @@ class Rules:
if not input_task:
break
current_input_value = await fetch(input_task)
- if make_hash(current_input_value) != past_input_hash:
+ if _make_hash(current_input_value) != past_input_hash:
break
else:
if output_value == past_value:
@@ -172,22 +171,19 @@ class Rules:
possible_values.append(past_value)
if possible_values:
- fetch.build.key_value[fetch.task.task_key] = possible_values[0]
+ store.key_value[task_key] = possible_values[0]
return possible_values[0]
new_inputs = []
- async def track(task: Task):
+ async def track_fetch(task: Task):
result = await fetch(task)
- new_inputs.append((task.task_key, make_hash(result)))
+ new_inputs.append((task.task_key, _make_hash(result)))
return result
- task = Task.new(rule, *args)
- new_value = await task(Fetch(track, task, fetch.build))
- fetch.build.key_value[fetch.task.task_key] = new_value
- fetch.build.key_info[fetch.task.task_key].append(
- (new_inputs, new_value)
- )
+ new_value = await rule.rule_fn(track_fetch, task_key, store, *args)
+ store.key_value[task_key] = new_value
+ store.key_info[task_key].append((new_inputs, new_value))
return new_value
wrapped_rule = Rule(rule.rule_key, new_rule_fn)
@@ -197,53 +193,23 @@ class Rules:
return decorator
-rules = Rules()
-rule = rules.rule()
-ctRebuilder = rules.ctRebuilder()
-
-
-# Example rule
-@ctRebuilder
-@rule
-async def eg_six(fetch: Fetch):
- _ = fetch
- print(f"{6=}")
- return 6
-
+_rules = Rules()
+rule = _rules.rule
+rawrule = _rules.rawrule
+cache = _rules.cache()
-# Example of a rule with a dependency
-@rule
-async def eg_thirtysix(fetch: Fetch):
- # Rules should be called to get tasks
- # In this case, the rule had 0 tasks
- # task = eg_six()
- # call fetch to mark a dependency,
- # (and begins execution of it in parallel if possible.)
- # `await` it get the result of the dependency
- six1 = await fetch(eg_six())
- six2 = await fetch(eg_six())
- print(f"{six1 * six2=}")
- return six1 * six2
-
-# Tasks can be parameterized based on other tasks or just normal values.
-@rule
-async def eg_multiply_add(fetch: Fetch, taskA: Task, taskB: Task, num: int):
- a, b = await asyncio.gather(fetch(taskA), fetch(taskB))
- print(f"{a * b + num=}")
- return a * b + num
-
-
-def fNone():
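+# A named module-level factory (rather than a lambda) keeps the defaultdict picklable.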
+def _fNone():
return None
-class Build:
+class Store:
def __init__(self, filename, rules):
self.filename = filename
self.rules = rules
- self.key_value = collections.defaultdict(fNone)
+ self.mutex = asyncio.Semaphore()
+ self.key_value = collections.defaultdict(_fNone)
self.key_info = collections.defaultdict(list)
try:
@@ -252,42 +218,150 @@ class Build:
except:
pass
+ def save(self):
+ with open(self.filename, "wb") as f:
+ pickle.dump((self.key_value, self.key_info), f)
+
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
- with open(self.filename, "wb") as f:
- pickle.dump((self.key_value, self.key_info), f)
+ self.save()
+
+
+_background_tasks = set()
+
+
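+# Schedule a coroutine to run in the background. The set keeps a strong
+# reference so the task is not garbage-collected mid-flight, as the asyncio
+# docs recommend for fire-and-forget tasks.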
+def detach(*args, **kwargs):
+ task = asyncio.create_task(*args, **kwargs)
+ _background_tasks.add(task)
+ task.add_done_callback(_background_tasks.discard)
+
+
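+# The paper's suspending scheduler: the first fetch of a task key runs the
+# task; concurrent fetches of the same key suspend on an Event and reuse the
+# stored result once it completes.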
+class SuspendingFetch:
+ def __init__(self, store: Store):
+ self.store = store
+ self.done = dict()
+ self.waits = dict()
+
+ async def __call__(self, task: Task):
+ # Return the task's value, waiting for all detached jobs to finish first.
+ result = await self.fetch(task)
+ await self.wait()
+ return result
+
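+ # Detached jobs may detach further jobs, so keep draining until the set is empty.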
+ async def wait(self):
+ while _background_tasks:
+ await asyncio.gather(*_background_tasks)
+
+ async def fetch(self, task: Task):
+ task_key = task.task_key
+ if task_key in self.done:
+ return self.done[task_key]
+ if task_key in self.waits:
+ # Another fetch is already running this task; suspend until it finishes.
+ await self.waits[task_key].wait()
+ return self.done[task_key]
+
+ event = self.waits[task_key] = asyncio.Event()
+ result = await task(self.fetch, self.store)
+ self.done[task_key] = result
+ event.set()
+
+ return result
+
+
+# Example rules
+# Observe the general pattern that every rule is called to get a task, which can then be fetched.
+# res = await fetch(rule(task_args...))
+
+
+@cache
+@rule
+async def _eg_six(fetch: FetchFn):
+ _ = fetch
+ six = 6
+ print(f"{six=}")
+ return six
@rule
-async def eg_file(fetch: Fetch, filename: str):
- print("file", filename)
+async def _eg_thirtysix(fetch: FetchFn):
+ # Here we await the dependencies serially.
+ # The second dependency cannot start until the first finishes.
+ six1 = await fetch(_eg_six())
+ six2 = await fetch(_eg_six())
+ print(f"{six1*six2=}")
+ return six1 * six2
+
+
+@rule
+async def _eg_multiply_add(fetch: FetchFn, taskA: Task, taskB: Task, num: int):
+ # Here we await the dependencies in parallel.
+ a, b = await asyncio.gather(fetch(taskA), fetch(taskB))
+ await asyncio.sleep(0.1)
+ print(f"{a*b+num=}")
+ return a * b + num
+
+
+# When interfacing with inputs, or in general anything outside the build system,
+# do NOT add @cache: a cached task reruns only when a tracked dependency is known to have changed.
+# Here we have no tracked dependencies and the output depends on the filesystem,
+# so we leave @cache off to ensure we always re-check that the file has not changed.
+@rule
+async def _eg_file(fetch: FetchFn, filename: str):
+ _ = fetch
+ await asyncio.sleep(0.1)
with open(filename, "r") as f:
- return f.read()
+ contents = f.readlines()
+ print("file", filename, "\n" + "".join(contents[1:5]), end="")
+ return contents
+
+
+# Semaphores can be used to limit concurrency
+_sem = asyncio.Semaphore(4)
-@ctRebuilder
+@cache
@rule
-async def eg_rec(fetch: Fetch, i: int):
- print("rec", i)
- j = len(await fetch(eg_file("make.py"))) % 2
- if i > 0:
- await fetch(eg_rec(i - 1 - j))
- await fetch(eg_rec(i - 1 - j))
+async def _eg_rec(fetch: FetchFn, i: int):
+ if i // 3 - 1 >= 0:
+ # Instead of awaiting, dependencies can also be detached and run in the background.
+ detach(fetch(_eg_rec(i // 2 - 1)))
+ detach(fetch(_eg_rec(i // 3 - 1)))
else:
- print("\n".join((await fetch(eg_file("make.py"))).splitlines()[:9]))
-
+ detach(fetch(_eg_file("make.py")))
+
+ # Use semaphore to limit concurrency easily
+ async with _sem:
+ print("+ rec", i)
+ # Simulate some hard work
+ await asyncio.sleep(0.1)
+ print("- rec", i)
+
+
+async def run_examples():
+ # To actually run the build system:
+ # 1) Create the store.
+ # Use a context manager so the store is saved automatically on exit.
+ with Store("make.db", _rules) as store:
+ # 2) Create the fetch callable
+ fetch = SuspendingFetch(store)
+ # 3) Use it to await tasks
+ await fetch(_eg_rec(1234))
+ await asyncio.gather(
+ fetch(_eg_thirtysix()), fetch(_eg_multiply_add(_eg_six(), _eg_six(), 6))
+ )
-if __name__ == "__main__":
- with Build("make.db", rules) as build:
- done = dict()
+ # Note that `fetch(...)` waits for all detached jobs to complete before returning.
+ # The lower-level `fetch.fetch(...)` does not; if you use it, make sure
+ # `fetch.wait()` is called later so detached jobs can finish.
+ await fetch.fetch(_eg_rec(2345))
+ await fetch.fetch(_eg_rec(3456))
+ await fetch.wait()
- async def fetch(task: Task):
- if task.task_key in done:
- return done[task.task_key]
- result = await task(Fetch(fetch, task, build))
- done[task.task_key] = result
- return result
- asyncio.run(fetch(eg_rec(10)))
+if __name__ == "__main__":
+ asyncio.run(run_examples())