rebuild.py (7403B)
"""Persistent, conditionally invalidated memoization for coroutines.

A result cached by `cache_conditionally` is reused only while every
`rerun_if_changed` check recorded during its computation still evaluates to
the value it had at cache time. The checks themselves are pickled by value
(code objects, closure cells and all) so they survive across processes in
the on-disk db managed by `Rerunner`.

Relies on recent CPython: 3.11+ for `co_qualname`/`co_exceptiontable`, the
`closure` keyword of `exec()`, and the `context` keyword of
`asyncio.create_task()`; the bare `ast.arguments()`/`ast.AsyncFunctionDef`
constructors below assume 3.13's defaulted AST fields.
"""

from asyncio import create_task
from contextvars import ContextVar, copy_context
from copyreg import dispatch_table
from functools import update_wrapper, wraps
from importlib import import_module
from inspect import getmodule
from io import BytesIO
from pickle import Pickler
from types import CellType, CodeType, FunctionType
from typing import Any
import ast
import inspect
import pickle


def pickle_code_type(code: CodeType):
    # Reduce a code object to the CodeType constructor arguments, in the
    # positional order CPython 3.11+ expects.
    return (
        unpickle_code_type,
        (
            code.co_argcount,
            code.co_posonlyargcount,
            code.co_kwonlyargcount,
            code.co_nlocals,
            code.co_stacksize,
            code.co_flags,
            code.co_code,
            code.co_consts,
            code.co_names,
            code.co_varnames,
            code.co_filename,
            code.co_name,
            code.co_qualname,
            code.co_firstlineno,
            code.co_linetable,
            code.co_exceptiontable,
            code.co_freevars,
            code.co_cellvars,
        ),
    )


def unpickle_code_type(*args):
    return CodeType(*args)


def pickle_cell_type(cell: CellType):
    # Closure cells are pickled by value: a snapshot of the cell's contents.
    return (unpickle_cell_type, (cell.cell_contents,))


def unpickle_cell_type(*args):
    return CellType(*args)


def pickle_function_type(f: FunctionType):
    # Pickle a function by value: its code object, the name of the module
    # whose namespace becomes its globals on unpickling, and snapshots of
    # its closure cells.
    mod = getmodule(f)
    return (
        unpickle_function_type,
        (
            f.__code__,
            mod.__name__ if mod is not None else None,
            (
                tuple(CellType(cell.cell_contents) for cell in f.__closure__)
                if f.__closure__
                else None
            ),
        ),
    )


def unpickle_function_type(code, mod_name, closure):
    return FunctionType(code, globals=import_module(mod_name).__dict__, closure=closure)


class FunctionPickler(Pickler):
    """Pickler that serializes anonymous/local functions by value.

    Functions reachable as attributes of their module are left to the
    default pickle-by-reference machinery; everything else (lambdas, nested
    functions) is reduced via `pickle_function_type`.
    """

    dispatch_table = dispatch_table.copy()
    dispatch_table[CodeType] = pickle_code_type
    dispatch_table[CellType] = pickle_cell_type

    def reducer_override(self, obj):  # type: ignore
        if type(obj) is not FunctionType:
            return NotImplemented
        obj_mod = getmodule(obj)
        if obj_mod is None:
            return NotImplemented
        if obj.__name__ in dir(obj_mod):
            # Importable by name: let pickle reference it normally.
            return NotImplemented
        return pickle_function_type(obj)


def pickle_with(pickler_cls: type, obj: Any) -> bytes:
    buf = BytesIO()
    pickler_cls(buf).dump(obj)
    return buf.getvalue()


# Per-context state: the persistent result db (installed by Rerunner) and the
# list of (old_value, pickled_check) pairs recorded during the current run.
rerun_db_var: ContextVar[dict] = ContextVar("rerun_db")
rerun_changes_var: ContextVar[list[tuple[Any, bytes]]] = ContextVar("rerun_changes")


async def with_rerun_context(rerun_changes, f, /, *args, **kwargs):
    # Runs inside its own task context (see cache_conditionally), so setting
    # the contextvar here does not leak into the caller.
    rerun_changes_var.set(rerun_changes)
    return await f(*args, **kwargs)


def rewrite_rerun_if_changed(frame=None):
    """Rewrite bare `rerun_if_changed(...)(expr)` statements in a coroutine.

    Each such statement becomes

        @rerun_if_changed(expr)   # `expr` evaluated now
        async def _():
            return expr           # re-evaluated from a pickle on later runs

    and the rewritten coroutine is exec'd back into the caller's namespace.
    The decorated coroutine must sit at the top level of its source file:
    the decorator lines are stripped by scanning for an unindented
    `async def`.
    """

    def decorator(fn):
        src_lines = inspect.getsource(fn).splitlines()
        i = 0
        while not src_lines[i].startswith("async def"):
            i += 1
        a = ast.parse("\n".join(src_lines[i:])).body[0]

        class RewriteCalls(ast.NodeTransformer):
            def visit_Expr(self, node: ast.Expr):
                # Match expression statements of the shape
                # `rerun_if_changed(<now>?)(<later-expr>)`.
                if (
                    isinstance(node.value, ast.Call)
                    and isinstance(node.value.func, ast.Call)
                    and isinstance(node.value.func.func, ast.Name)
                    and node.value.func.func.id == "rerun_if_changed"
                ):
                    if len(node.value.func.args) == 0:
                        # No explicit `now`: default it to the watched expr.
                        node.value.func.args.append(node.value.args[0])
                    return ast.AsyncFunctionDef(
                        "_",
                        ast.arguments(),
                        [ast.Return(node.value.args[0])],
                        [node.value.func],  # decorator: rerun_if_changed(now)
                    )
                return node

        a = ast.fix_missing_locations(RewriteCalls().visit(a))
        # print(ast.unparse(a))
        frame_ = frame if frame else inspect.currentframe().f_back  # type: ignore
        exec(ast.unparse(a), frame_.f_globals, frame_.f_locals, closure=fn.__closure__)  # type: ignore
        # While a function's decorators run, its name is not yet bound, so
        # the exec'd definition is the newest entry in the insertion-ordered
        # namespace dict.
        fn_ = list(frame_.f_locals.values())[-1]  # type: ignore
        return update_wrapper(fn_, fn)

    return decorator


class _RunLaterNow: ...  # sentinel default for `now`


def rerun_if_changed(now: Any = _RunLaterNow, *, pickler_cls=FunctionPickler):
    """Record a recheck: the enclosing tracked coroutine reruns once `later()`
    stops returning `now`. Meant to be written as a bare
    `rerun_if_changed(...)(expr)` statement and preprocessed by
    `rewrite_rerun_if_changed`."""

    def decorator(later):
        later_pkl = pickle_with(pickler_cls, later)
        if now is _RunLaterNow:
            raise RuntimeError(
                "Should have been preprocessed away by the cache_conditionally macro"
            )
        rerun_changes_var.get().append((now, later_pkl))

    return decorator


def rerun_if(f):
    # Invalidate the cached result once `f()` becomes truthy.
    @rerun_if_changed(False)
    async def _():
        return bool(await f())


def rerun_always():
    # Invalidate the cached result on every run.
    @rerun_if_changed(False)
    async def _():
        return True


def cache_conditionally(
    keys_fn=lambda *args, **kwargs: (args, tuple(sorted(kwargs.items()))),
    store_fn=lambda result, /, *_, **__: result,
    load_fn=lambda cached_result, /, *_, **__: cached_result,
    rewrite=True,
):
    """Memoize a coroutine into the db installed by `Rerunner`.

    A cached result is returned only if every recheck recorded during its
    computation still yields its old value; a changed value, or any
    exception while rechecking, forces recomputation.
    """

    def decorator(fn):
        if rewrite:
            fn = rewrite_rerun_if_changed(inspect.currentframe().f_back)(fn)  # type: ignore

        @wraps(fn)
        async def wrapped(*args, **kwargs):
            db = rerun_db_var.get()
            keys = keys_fn(*args, **kwargs)
            db_key = ("track", fn.__qualname__, keys)
            if db_key + ("result",) in db:
                old_rerun_changes = db.setdefault(db_key + ("rerun_changes",), [])
                for old_val, f_pkl in old_rerun_changes:
                    try:
                        if old_val != await pickle.loads(f_pkl)():
                            break  # a watched value changed: recompute
                    except BaseException:
                        break  # a recheck failed: recompute
                else:
                    return load_fn(db[db_key + ("result",)], *args, **kwargs)

            # Run in a fresh task and copied context so rerun_changes_var
            # stays local to this computation.
            rerun_changes = []
            result = await create_task(
                with_rerun_context(rerun_changes, fn, *args, **kwargs),
                context=copy_context(),
            )
            db[db_key + ("rerun_changes",)] = rerun_changes
            db[db_key + ("result",)] = store_fn(result, *args, **kwargs)
            return result

        return wrapped

    return decorator


class Rerunner:
    """Context manager owning the persistent pickle db used by the cache."""

    def __init__(self, db_filename=b".makedb", db_file=None):
        # "a+b" creates the file if missing without truncating an existing db.
        self.db_file = db_file if db_file else open(db_filename, "a+b")
        self.db_file.seek(0)

    def __enter__(self):
        self.db_file.__enter__()
        try:
            self.db = pickle.load(self.db_file)
        except (pickle.PickleError, EOFError):
            # Empty or corrupt db file: start from scratch.
            self.db = dict()
        self.var_tok = rerun_db_var.set(self.db)
        return self

    def __exit__(self, ty, exc, tb):
        rerun_db_var.reset(self.var_tok)
        if exc is None:
            # Persist the db only on a clean exit.
            self.db_file.seek(0)
            self.db_file.truncate(0)
            pickle.dump(self.db, self.db_file)
        self.db_file.__exit__(ty, exc, tb)
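

# --- Usage sketch (added for illustration; not part of the original file) ---
# A minimal, hypothetical driver script, assuming it sits next to rebuild.py;
# the name `build` and the file "hello.txt" are invented. Because
# rewrite_rerun_if_changed() locates the `async def` line with startswith(),
# the decorated coroutine must be defined unindented at module top level.
#
#     # build_demo.py
#     import asyncio
#     import os
#
#     from rebuild import Rerunner, cache_conditionally, rerun_if_changed
#
#     @cache_conditionally()
#     async def build(src):
#         # Watched expression: its current value is stored alongside the
#         # result, and the thunk returning it is pickled for later rechecks.
#         rerun_if_changed()(os.stat(src).st_mtime)
#         print("building", src)
#         return src + ".out"
#
#     async def main():
#         with Rerunner():
#             await build("hello.txt")  # cache miss: builds, records recheck
#             await build("hello.txt")  # mtime unchanged: served from the db
#
#     asyncio.run(main())
#
# The same holds across processes: the db is written to .makedb on clean
# exit, so rerunning the script rebuilds only if hello.txt's mtime changed.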