__init__.py (7689B)
1 import asyncio 2 from asyncio.futures import Future 3 from contextlib import asynccontextmanager 4 from contextvars import ContextVar 5 from types import CoroutineType 6 from typing import ( 7 Any, 8 Awaitable, 9 Callable, 10 ClassVar, 11 Generator, 12 Generic, 13 Protocol, 14 TypeAlias, 15 TypeVar, 16 runtime_checkable, 17 ) 18 import dataclasses 19 20 # Alias confusing names that clash with the build system 21 AIOTask = asyncio.Task 22 AIOTaskGroup = asyncio.TaskGroup 23 AIOEvent = asyncio.Event 24 25 26 CovT = TypeVar("CovT", covariant=True) 27 ConT = TypeVar("ConT", contravariant=True) 28 29 DbEntry: TypeAlias = object 30 31 Factory: TypeAlias = Callable[[], CovT] 32 33 34 @runtime_checkable 35 class HashLike(Protocol): 36 def __hash__(self) -> int: ... 37 def __eq__(self, other, /) -> bool: ... 38 39 40 @dataclasses.dataclass(frozen=True) 41 class TaskKey: 42 rule_key: str 43 arg_keys: tuple["TaskKey | HashLike"] 44 45 46 @dataclasses.dataclass(slots=True) 47 class Box(Generic[CovT]): 48 x: CovT = dataclasses.field() 49 50 51 build_tg_: ContextVar[AIOTaskGroup] = ContextVar("Build asyncio.TaskGroup") 52 build_seq_: ContextVar[Box[int]] = ContextVar("Build sequence number") 53 54 55 def build_tg(): 56 return build_tg_.get() 57 58 59 def build_seq(): 60 seq = build_seq_.get() 61 seq.x += 1 62 return seq.x 63 64 65 @asynccontextmanager 66 async def BuildContext(): 67 async with AIOTaskGroup() as tg: 68 tok = build_tg_.set(tg) 69 build_seq_.set(Box(0)) 70 try: 71 yield 72 finally: 73 build_tg_.reset(tok) 74 75 76 @dataclasses.dataclass 77 class TaskContext: 78 task: "Task" = dataclasses.field() 79 80 81 async def noop(): 82 pass 83 84 85 @dataclasses.dataclass(slots=True) 86 class LazyTask(Generic[CovT]): 87 """ 88 Lazy AIOTask 89 90 Lazy: The coroutine isn't run or spawned into a task until .start()ed or awaited. 91 Use await directly if you intend to await -- this runs the coroutine in the current task. 92 """ 93 94 coro: CoroutineType[Any, None, CovT] 95 # Not started | Starting | Started 96 task: None | Future[CovT] = None 97 98 def start(self) -> Awaitable[CovT]: 99 """Ensures task is started, and returns a coroutine for its completion.""" 100 if self.task is None: 101 self.task = build_tg().create_task(self.coro) 102 return self.task 103 104 def __await__(self) -> Generator[Any, None, CovT]: 105 if self.task is None: 106 # Run the coroutine directly to avoid creating an additional task. 107 self.task = asyncio.get_event_loop().create_future() 108 try: 109 res = yield from self.coro.__await__() 110 self.task.set_result(res) 111 return res 112 except BaseException as exc: 113 self.task.set_exception(exc) 114 raise 115 else: 116 res = yield from self.task.__await__() 117 return res 118 119 120 @dataclasses.dataclass(slots=True, frozen=True) 121 class LazyFuture(Generic[CovT]): 122 """ 123 Future that lazily starts its own computation. 124 125 The handle is run the first time .start() is called or the future is awaited. 126 127 For convenience, LazyFutures are also context managers that intercept exceptions. 128 """ 129 130 handle: LazyTask | None = dataclasses.field(default=None) 131 fut: Future[CovT] = dataclasses.field( 132 default_factory=lambda: asyncio.get_event_loop().create_future() 133 ) 134 135 none: ClassVar["LazyFuture[None]"] 136 137 def __enter__(self): 138 return self.fut.set_result 139 140 def __exit__(self, exc_type, exc, tb): 141 if not self.fut.done(): 142 self.fut.set_exception(exc) 143 144 def start(self): 145 """Ensures handle is started if any, then returns the future.""" 146 if self.handle: 147 self.handle.start() 148 return self.fut 149 150 def __await__(self) -> Generator[Any, None, CovT]: 151 """Ensures handle is started if any, then awaits the future.""" 152 if self.handle is not None: 153 yield from self.handle.start().__await__() 154 res = yield from self.fut.__await__() 155 return res 156 157 158 LazyFuture.none = LazyFuture() 159 LazyFuture.none.fut.set_result(None) 160 161 162 CovT = TypeVar("CovT", covariant=True) 163 164 165 class Unset: ... 166 167 168 unset = Unset() 169 170 171 @dataclasses.dataclass(frozen=True) 172 class Task(Generic[CovT]): 173 """ 174 A task is a computation of the task's value and hash. 175 176 Values are returned when other tasks fetch this task. 177 178 Other tasks can also obtain this task's hash, but hashes should be viewed 179 as a way to certify that "a value" is the same as some "other value", 180 typically to assert that the value from a previous run is good. 181 """ 182 183 key: TaskKey = dataclasses.field() 184 value: LazyFuture[CovT] = dataclasses.field(compare=False) 185 hash: LazyFuture[HashLike] = dataclasses.field(compare=False) 186 187 188 @dataclasses.dataclass 189 class TaskRun(Generic[CovT]): 190 """ 191 A task run is a recording of how the task ran last time. 192 """ 193 194 key: TaskKey = dataclasses.field() 195 deps: dict[TaskKey, "TaskRun"] = dataclasses.field(default_factory=dict) 196 value: CovT | Unset = dataclasses.field(default=unset) 197 hash: HashLike | Unset = dataclasses.field(default=unset) 198 199 200 """ 201 Traces record information of a run so work can be saved in the future. 202 203 There are 3 variants, dep traces, hash traces and value traces. 204 205 - A dep trace records a task's deps. 206 - A hash trace records a task's deps, dep hashes, and *out hash*. 207 - A value trace records a task's deps, dep hashes, *out hash, and out value*. 208 209 Trace builders are classes to aid in the production of a trace. 210 Once a trace builder can be finalized to obtain a true trace. 211 212 As such, there are 4 classes: 213 214 - DepTrace 215 - HashTrace 216 - ValueTrace 217 - DepTraceBuilder 218 - HashTraceBuilder 219 - ValueTraceBuilder 220 """ 221 222 223 @dataclasses.dataclass 224 class DepTrace: 225 key: TaskKey = dataclasses.field() 226 # We still record HashLike so HashTrace and ValueTrace can subclass DepTrace. 227 # The hash is simply ignored if we only want to know the dependencies of key. 228 in_hashes: dict[TaskKey, HashLike] = dataclasses.field() 229 230 231 @dataclasses.dataclass 232 class HashTrace(DepTrace): 233 out_hash: HashLike = dataclasses.field() 234 235 236 @dataclasses.dataclass 237 class ValueTrace(HashTrace, Generic[CovT]): 238 out_value: CovT = dataclasses.field() 239 240 241 @dataclasses.dataclass 242 class DepTraceBuilder: 243 key: TaskKey = dataclasses.field() 244 in_hashes: dict[TaskKey, LazyFuture[HashLike]] = dataclasses.field( 245 default_factory=dict 246 ) 247 248 async def finalize(self) -> DepTrace: 249 return DepTrace( 250 key=self.key, 251 in_hashes={k: await v for k, v in self.in_hashes.items()}, 252 ) 253 254 255 @dataclasses.dataclass 256 class HashTraceBuilder(DepTraceBuilder): 257 out_hash: LazyFuture[HashLike] | None = dataclasses.field(default=None) 258 259 async def finalize(self) -> HashTrace: 260 if self.out_hash is None: 261 raise RuntimeError("Cannot finalize HashTraceBuilder without out_hash") 262 return HashTrace( 263 key=self.key, 264 in_hashes={k: await v for k, v in self.in_hashes.items()}, 265 out_hash=await self.out_hash, 266 ) 267 268 269 @dataclasses.dataclass 270 class ValueTraceBuilder(HashTraceBuilder, Generic[CovT]): 271 out_value: LazyFuture[CovT] | None = dataclasses.field(default=None) 272 273 async def finalize(self) -> ValueTrace[CovT]: 274 if self.out_hash is None: 275 raise RuntimeError("Cannot finalize ValueTraceBuilder without out_hash") 276 if self.out_value is None: 277 raise RuntimeError("Cannot finalize ValueTraceBuilder without out_value") 278 return ValueTrace( 279 key=self.key, 280 in_hashes={k: await v for k, v in self.in_hashes.items()}, 281 out_hash=await self.out_hash, 282 out_value=await self.out_value, 283 )