pymake

A build system based on Build Systems à la Carte
git clone https://git.grace.moe/pymake
Log | Files | Refs | README

__init__.py (7689B)


      1 import asyncio
      2 from asyncio.futures import Future
      3 from contextlib import asynccontextmanager
      4 from contextvars import ContextVar
      5 from types import CoroutineType
      6 from typing import (
      7     Any,
      8     Awaitable,
      9     Callable,
     10     ClassVar,
     11     Generator,
     12     Generic,
     13     Protocol,
     14     TypeAlias,
     15     TypeVar,
     16     runtime_checkable,
     17 )
     18 import dataclasses
     19 
# Alias confusing names that clash with the build system
# (this module defines its own Task class, so asyncio's names get an AIO prefix).
AIOTask = asyncio.Task
AIOTaskGroup = asyncio.TaskGroup
AIOEvent = asyncio.Event


# Generic type parameters used by the generic classes in this module.
CovT = TypeVar("CovT", covariant=True)
ConT = TypeVar("ConT", contravariant=True)

# NOTE(review): DbEntry is not referenced in the visible part of the file;
# presumably a placeholder for a stored record type -- confirm.
DbEntry: TypeAlias = object

# A zero-argument callable producing a CovT.
Factory: TypeAlias = Callable[[], CovT]
     32 
     33 
     34 @runtime_checkable
     35 class HashLike(Protocol):
     36     def __hash__(self) -> int: ...
     37     def __eq__(self, other, /) -> bool: ...
     38 
     39 
     40 @dataclasses.dataclass(frozen=True)
     41 class TaskKey:
     42     rule_key: str
     43     arg_keys: tuple["TaskKey | HashLike"]
     44 
     45 
     46 @dataclasses.dataclass(slots=True)
     47 class Box(Generic[CovT]):
     48     x: CovT = dataclasses.field()
     49 
     50 
# Context-local state installed by BuildContext() and read through
# build_tg() / build_seq(). The string argument is only the variable's name.
build_tg_: ContextVar[AIOTaskGroup] = ContextVar("Build asyncio.TaskGroup")
# Held in a Box so build_seq() can bump the counter in place without .set().
build_seq_: ContextVar[Box[int]] = ContextVar("Build sequence number")
     53 
     54 
     55 def build_tg():
     56     return build_tg_.get()
     57 
     58 
     59 def build_seq():
     60     seq = build_seq_.get()
     61     seq.x += 1
     62     return seq.x
     63 
     64 
     65 @asynccontextmanager
     66 async def BuildContext():
     67     async with AIOTaskGroup() as tg:
     68         tok = build_tg_.set(tg)
     69         build_seq_.set(Box(0))
     70         try:
     71             yield
     72         finally:
     73             build_tg_.reset(tok)
     74 
     75 
     76 @dataclasses.dataclass
     77 class TaskContext:
     78     task: "Task" = dataclasses.field()
     79 
     80 
     81 async def noop():
     82     pass
     83 
     84 
     85 @dataclasses.dataclass(slots=True)
     86 class LazyTask(Generic[CovT]):
     87     """
     88     Lazy AIOTask
     89 
     90     Lazy: The coroutine isn't run or spawned into a task until .start()ed or awaited.
     91     Use await directly if you intend to await -- this runs the coroutine in the current task.
     92     """
     93 
     94     coro: CoroutineType[Any, None, CovT]
     95     # Not started | Starting | Started
     96     task: None | Future[CovT] = None
     97 
     98     def start(self) -> Awaitable[CovT]:
     99         """Ensures task is started, and returns a coroutine for its completion."""
    100         if self.task is None:
    101             self.task = build_tg().create_task(self.coro)
    102         return self.task
    103 
    104     def __await__(self) -> Generator[Any, None, CovT]:
    105         if self.task is None:
    106             # Run the coroutine directly to avoid creating an additional task.
    107             self.task = asyncio.get_event_loop().create_future()
    108             try:
    109                 res = yield from self.coro.__await__()
    110                 self.task.set_result(res)
    111                 return res
    112             except BaseException as exc:
    113                 self.task.set_exception(exc)
    114                 raise
    115         else:
    116             res = yield from self.task.__await__()
    117             return res
    118 
    119 
    120 @dataclasses.dataclass(slots=True, frozen=True)
    121 class LazyFuture(Generic[CovT]):
    122     """
    123     Future that lazily starts its own computation.
    124 
    125     The handle is run the first time .start() is called or the future is awaited.
    126 
    127     For convenience, LazyFutures are also context managers that intercept exceptions.
    128     """
    129 
    130     handle: LazyTask | None = dataclasses.field(default=None)
    131     fut: Future[CovT] = dataclasses.field(
    132         default_factory=lambda: asyncio.get_event_loop().create_future()
    133     )
    134 
    135     none: ClassVar["LazyFuture[None]"]
    136 
    137     def __enter__(self):
    138         return self.fut.set_result
    139 
    140     def __exit__(self, exc_type, exc, tb):
    141         if not self.fut.done():
    142             self.fut.set_exception(exc)
    143 
    144     def start(self):
    145         """Ensures handle is started if any, then returns the future."""
    146         if self.handle:
    147             self.handle.start()
    148         return self.fut
    149 
    150     def __await__(self) -> Generator[Any, None, CovT]:
    151         """Ensures handle is started if any, then awaits the future."""
    152         if self.handle is not None:
    153             yield from self.handle.start().__await__()
    154         res = yield from self.fut.__await__()
    155         return res
    156 
    157 
    158 LazyFuture.none = LazyFuture()
    159 LazyFuture.none.fut.set_result(None)
    160 
    161 
# NOTE(review): CovT is already bound earlier in this module with identical
# arguments; this rebinding creates a second, distinct TypeVar object that the
# definitions below pick up.
CovT = TypeVar("CovT", covariant=True)
    163 
    164 
    165 class Unset: ...
    166 
    167 
    168 unset = Unset()
    169 
    170 
    171 @dataclasses.dataclass(frozen=True)
    172 class Task(Generic[CovT]):
    173     """
    174     A task is a computation of the task's value and hash.
    175 
    176     Values are returned when other tasks fetch this task.
    177 
    178     Other tasks can also obtain this task's hash, but hashes should be viewed
    179     as a way to certify that "a value" is the same as some "other value",
    180     typically to assert that the value from a previous run is good.
    181     """
    182 
    183     key: TaskKey = dataclasses.field()
    184     value: LazyFuture[CovT] = dataclasses.field(compare=False)
    185     hash: LazyFuture[HashLike] = dataclasses.field(compare=False)
    186 
    187 
    188 @dataclasses.dataclass
    189 class TaskRun(Generic[CovT]):
    190     """
    191     A task run is a recording of how the task ran last time.
    192     """
    193 
    194     key: TaskKey = dataclasses.field()
    195     deps: dict[TaskKey, "TaskRun"] = dataclasses.field(default_factory=dict)
    196     value: CovT | Unset = dataclasses.field(default=unset)
    197     hash: HashLike | Unset = dataclasses.field(default=unset)
    198 
    199 
    200 """
    201 Traces record information of a run so work can be saved in the future.
    202 
    203 There are 3 variants, dep traces, hash traces and value traces.
    204 
    205 - A dep trace records a task's deps.
    206 - A hash trace records a task's deps, dep hashes, and *out hash*.
    207 - A value trace records a task's deps, dep hashes, *out hash, and out value*.
    208 
    209 Trace builders are classes to aid in the production of a trace.
Once complete, a trace builder can be finalized to obtain a true trace.
    211 
As such, there are 6 classes:
    213 
    214 - DepTrace
    215 - HashTrace
    216 - ValueTrace
    217 - DepTraceBuilder
    218 - HashTraceBuilder
    219 - ValueTraceBuilder
    220 """
    221 
    222 
    223 @dataclasses.dataclass
    224 class DepTrace:
    225     key: TaskKey = dataclasses.field()
    226     # We still record HashLike so HashTrace and ValueTrace can subclass DepTrace.
    227     # The hash is simply ignored if we only want to know the dependencies of key.
    228     in_hashes: dict[TaskKey, HashLike] = dataclasses.field()
    229 
    230 
    231 @dataclasses.dataclass
    232 class HashTrace(DepTrace):
    233     out_hash: HashLike = dataclasses.field()
    234 
    235 
    236 @dataclasses.dataclass
    237 class ValueTrace(HashTrace, Generic[CovT]):
    238     out_value: CovT = dataclasses.field()
    239 
    240 
    241 @dataclasses.dataclass
    242 class DepTraceBuilder:
    243     key: TaskKey = dataclasses.field()
    244     in_hashes: dict[TaskKey, LazyFuture[HashLike]] = dataclasses.field(
    245         default_factory=dict
    246     )
    247 
    248     async def finalize(self) -> DepTrace:
    249         return DepTrace(
    250             key=self.key,
    251             in_hashes={k: await v for k, v in self.in_hashes.items()},
    252         )
    253 
    254 
    255 @dataclasses.dataclass
    256 class HashTraceBuilder(DepTraceBuilder):
    257     out_hash: LazyFuture[HashLike] | None = dataclasses.field(default=None)
    258 
    259     async def finalize(self) -> HashTrace:
    260         if self.out_hash is None:
    261             raise RuntimeError("Cannot finalize HashTraceBuilder without out_hash")
    262         return HashTrace(
    263             key=self.key,
    264             in_hashes={k: await v for k, v in self.in_hashes.items()},
    265             out_hash=await self.out_hash,
    266         )
    267 
    268 
    269 @dataclasses.dataclass
    270 class ValueTraceBuilder(HashTraceBuilder, Generic[CovT]):
    271     out_value: LazyFuture[CovT] | None = dataclasses.field(default=None)
    272 
    273     async def finalize(self) -> ValueTrace[CovT]:
    274         if self.out_hash is None:
    275             raise RuntimeError("Cannot finalize ValueTraceBuilder without out_hash")
    276         if self.out_value is None:
    277             raise RuntimeError("Cannot finalize ValueTraceBuilder without out_value")
    278         return ValueTrace(
    279             key=self.key,
    280             in_hashes={k: await v for k, v in self.in_hashes.items()},
    281             out_hash=await self.out_hash,
    282             out_value=await self.out_value,
    283         )