pymake

A build system based on Build Systems à la Carte
git clone https://git.grace.moe/pymake
Log | Files | Refs | README

rebuild.py (7403B)


      1 from asyncio import create_task
      2 from contextvars import ContextVar, copy_context
      3 from copyreg import dispatch_table
      4 from functools import update_wrapper, wraps
      5 from importlib import import_module
      6 from inspect import getmodule
      7 from io import BytesIO
      8 from pickle import Pickler
      9 from types import CellType, CodeType, FunctionType
     10 from typing import Any
     11 import ast
     12 import inspect
     13 import pickle
     14 
     15 
     16 def pickle_code_type(code: CodeType):
     17     return (
     18         unpickle_code_type,
     19         (
     20             code.co_argcount,
     21             code.co_posonlyargcount,
     22             code.co_kwonlyargcount,
     23             code.co_nlocals,
     24             code.co_stacksize,
     25             code.co_flags,
     26             code.co_code,
     27             code.co_consts,
     28             code.co_names,
     29             code.co_varnames,
     30             code.co_filename,
     31             code.co_name,
     32             code.co_qualname,
     33             code.co_firstlineno,
     34             code.co_linetable,
     35             code.co_exceptiontable,
     36             code.co_freevars,
     37             code.co_cellvars,
     38         ),
     39     )
     40 
     41 
     42 def unpickle_code_type(*args):
     43     return CodeType(*args)
     44 
     45 
     46 def pickle_cell_type(cell: CellType):
     47     return (unpickle_cell_type, (cell.cell_contents,))
     48 
     49 
     50 def unpickle_cell_type(*args):
     51     return CellType(*args)
     52 
     53 
     54 def pickle_function_type(f: FunctionType):
     55     mod = getmodule(f)
     56     return (
     57         unpickle_function_type,
     58         (
     59             f.__code__,
     60             mod.__name__ if mod is not None else None,
     61             (
     62                 tuple(CellType(cell.cell_contents) for cell in f.__closure__)
     63                 if f.__closure__
     64                 else None
     65             ),
     66         ),
     67     )
     68 
     69 
     70 def unpickle_function_type(code, mod_name, closure):
     71     return FunctionType(code, globals=import_module(mod_name).__dict__, closure=closure)
     72 
     73 
     74 class FunctionPickler(Pickler):
     75     dispatch_table = dispatch_table.copy()
     76     dispatch_table[CodeType] = pickle_code_type
     77     dispatch_table[CellType] = pickle_cell_type
     78 
     79     def reducer_override(self, obj):  # type: ignore
     80         if type(obj) is not FunctionType:
     81             return NotImplemented
     82         obj_mod = getmodule(obj)
     83         if obj_mod is None:
     84             return NotImplemented
     85         if obj.__name__ in dir(obj_mod):
     86             return NotImplemented
     87         return pickle_function_type(obj)
     88 
     89 
     90 def pickle_with(pickler_cls: type, obj: Any) -> bytes:
     91     i = BytesIO()
     92     pickler_cls(i).dump(obj)
     93     i.seek(0)
     94     return i.read()
     95 
     96 
# Per-context build database: maps tuple keys of the form
# ("track", qualname, call-key) + ("result",) / ("rerun_changes",) to cached
# results and their change-detector lists (see cache_conditionally).
rerun_db_var: ContextVar[dict] = ContextVar("rerun_db")
# Per-task accumulator of (observed_value, pickled_async_checker) pairs,
# appended to by rerun_if_changed while a tracked function runs.
rerun_changes_var: ContextVar[list[tuple[Any, bytes]]] = ContextVar("rerun_changes")
     99 
    100 
    101 async def with_rerun_context(rerun_changes, f, /, *args, **kwargs):
    102     rerun_changes_var.set(rerun_changes)
    103     return await f(*args, **kwargs)
    104 
    105 
def rewrite_rerun_if_changed(frame=None):
    """Decorator factory that macro-rewrites bare ``rerun_if_changed(...)(expr)``
    statements inside an ``async def`` body.

    Each such statement becomes ``@rerun_if_changed(now) async def _(): return
    expr`` so that ``expr`` is both evaluated now (as the recorded value) and
    kept, via the decorator, as a re-runnable checker.  ``frame`` is the frame
    whose globals/locals the rewritten source is exec'd in; defaults to the
    caller's frame.
    """

    def decorator(fn):
        # Strip leading decorator lines so only the "async def ..." source
        # is parsed.
        s = inspect.getsource(fn).splitlines()
        i = 0
        while not s[i].startswith("async def"):
            i += 1
        s = s[i:]
        s = "\n".join(s)
        a = ast.parse(s).body[0]

        class RewriteCalls(ast.NodeTransformer):
            def visit_Expr(self, node: ast.Expr):
                # Match statement-expressions of the exact shape
                #   rerun_if_changed(<maybe now>)(<expr>)
                if (
                    isinstance(node.value, ast.Call)
                    and isinstance(node.value.func, ast.Call)
                    and isinstance(node.value.func.func, ast.Name)
                    and node.value.func.func.id == "rerun_if_changed"
                ):
                    # Zero-arg form rerun_if_changed()(expr): reuse expr as
                    # the "now" argument.
                    if len(node.value.func.args) == 0:
                        node.value.func.args.append(node.value.args[0])
                    # Replace the statement with an async def "_" returning
                    # expr, decorated by the (now-completed) rerun_if_changed
                    # call.
                    out = ast.AsyncFunctionDef(
                        "_",
                        ast.arguments(),
                        [ast.Return(node.value.args[0])],
                        [node.value.func],
                    )
                    return out
                return node

        a = ast.fix_missing_locations(RewriteCalls().visit(a))
        # print(ast.unparse(a))
        # Re-exec the rewritten source in the original definition frame so
        # names resolve as they did at the definition site; the closure=
        # keyword (exec, Python 3.11+) rebinds fn's free variables.
        frame_ = frame if frame else inspect.currentframe().f_back  # type: ignore
        exec(ast.unparse(a), frame_.f_globals, frame_.f_locals, closure=fn.__closure__)  # type: ignore
        # Assumes the exec'd def lands as the most recently inserted local.
        # NOTE(review): relies on f_locals insertion order; verify this holds
        # at non-module scopes.
        fn_ = list(frame_.f_locals.values())[-1]  # type: ignore

        fn_ = update_wrapper(fn_, fn)
        return fn_

    return decorator
    145 
    146 
    147 class _RunLaterNow: ...
    148 
    149 
    150 def rerun_if_changed(now: Any = _RunLaterNow, *, pickler_cls=FunctionPickler):
    151     def decorator(later):
    152         later_pkl = pickle_with(pickler_cls, later)
    153         if now is _RunLaterNow:
    154             raise RuntimeError(
    155                 "Should have been preprocessed away by the cache_conditionally macro"
    156             )
    157         else:
    158             rerun_changes_var.get().append((now, later_pkl))
    159 
    160     return decorator
    161 
    162 
    163 def rerun_if(f):
    164     @rerun_if_changed(False)
    165     async def _():
    166         return bool(await f())
    167 
    168 
    169 def rerun_always():
    170     @rerun_if_changed(False)
    171     async def _():
    172         return True
    173 
    174 
def cache_conditionally(
    keys_fn=lambda *args, **kwargs: (args, tuple(sorted(kwargs.items()))),
    store_fn=lambda result, /, *_, **__: result,
    load_fn=lambda cached_result, /, *_, **__: cached_result,
    rewrite=True,
):
    """Cache an async function's result in the context build db, re-running
    it only when one of its registered change detectors reports a change.

    keys_fn: maps the call's (args, kwargs) to a hashable cache key.
    store_fn: transforms the result before it is stored in the db.
    load_fn: transforms a cached value before it is returned to the caller.
    rewrite: when True, preprocess bare ``rerun_if_changed(...)(expr)``
        statements in the body via rewrite_rerun_if_changed (using the
        caller's frame as the exec target).
    """

    def decorator(fn):
        if rewrite:
            fn = rewrite_rerun_if_changed(inspect.currentframe().f_back)(fn)  # type: ignore

        @wraps(fn)
        async def wrapped(*args, **kwargs):
            db = rerun_db_var.get()
            keys = keys_fn(*args, **kwargs)
            db_key = ("track", fn.__qualname__, keys)
            if db_key + ("result",) in db:
                # Entries without a detector list are treated as having an
                # empty one (nothing can invalidate the cached result).
                if db_key + ("rerun_changes",) not in db:
                    old_rerun_changes = []
                    db[db_key + ("rerun_changes",)] = old_rerun_changes
                else:
                    old_rerun_changes = db[db_key + ("rerun_changes",)]
                # Re-run each stored detector; a changed value or any
                # failure breaks out, marking the cache stale.
                for old_val, f_pkl in old_rerun_changes:
                    try:
                        f_unpkled = pickle.loads(f_pkl)
                        val = await f_unpkled()
                        if old_val != val:
                            break
                    except BaseException:
                        # Deliberate best-effort: a detector that cannot be
                        # unpickled or that raises simply invalidates the
                        # cache and forces a rerun.
                        break
                else:
                    # Every detector matched its recorded value: serve the
                    # cached result.
                    return load_fn(db[db_key + ("result",)], *args, **kwargs)

            # (Re)compute in a fresh task with a copied context so the
            # rerun_changes_var set inside with_rerun_context stays local to
            # this call; the new detector list replaces the old one.
            rerun_changes = []
            result = await create_task(
                with_rerun_context(rerun_changes, fn, *args, **kwargs),
                context=copy_context(),
            )
            db[db_key + ("rerun_changes",)] = rerun_changes
            db[db_key + ("result",)] = store_fn(result, *args, **kwargs)
            return result

        return wrapped

    return decorator
    219 
    220 
    221 class Rerunner:
    222     def __init__(self, db_filename=b".makedb", db_file=None):
    223         if db_file:
    224             self.db_file = db_file
    225         else:
    226             self.db_file = open(db_filename, "a+b")
    227             self.db_file.seek(0)
    228 
    229     def __enter__(self):
    230         self.db_file.__enter__()
    231         try:
    232             self.db = pickle.load(self.db_file)
    233         except pickle.PickleError:
    234             self.db = dict()
    235         except EOFError:
    236             self.db = dict()
    237         self.var_tok = rerun_db_var.set(self.db)
    238         return self
    239 
    240     def __exit__(self, ty, exc, tb):
    241         rerun_db_var.reset(self.var_tok)
    242         if exc is None:
    243             self.db_file.seek(0)
    244             self.db_file.truncate(0)
    245             pickle.dump(self.db, self.db_file)
    246         self.db_file.__exit__(ty, exc, tb)