RhodiumGroup/rhg_compute_tools

add simple caching utility

delgadom opened this issue · 3 comments

Some sort of easy caching utility would be great. Something like a decorator that can accept a filepattern and an overwrite argument on write.

Does something like this already exist? Also would be great to have this work with intake!

Proposed implementation

Lots to still work out here, but here's a stab:

import functools
import hashlib
import inspect
import os
import pickle

import pandas as pd
import toolz

@toolz.curry
def cache_results(
        func,
        storage_dir=None,
        storage_pattern=None,
        create_directories=False,
        reader=pd.read_csv,
        writer=lambda x, fp, **kwargs: x.to_csv(fp, **kwargs),
        ext='.csv',
        read_kwargs=None,
        write_kwargs=None):
    """
    Cache the result of ``func`` in a file, reading it back on later calls

    The decorated function accepts two extra keyword-only arguments:
    ``cache_path`` (explicit location of the cached file for this call) and
    ``overwrite`` (skip the read attempt and always recompute and rewrite).

    Parameters
    ----------
    func : function
        function to decorate. Cannot itself take ``cache_path`` or
        ``overwrite`` as arguments, because the wrapper consumes them.
    storage_dir : str, optional
        directory in which to store results when neither ``cache_path`` nor
        ``storage_pattern`` is given. The file name is a sha256 digest of the
        pickled call arguments plus ``ext``.
    storage_pattern : str, optional
        format string filled in with the call's arguments (by parameter
        name) to build the cache path. Takes precedence over ``storage_dir``.
    create_directories : bool, optional
        if True, create the parent directory of the cache path before
        writing.
    reader : callable, optional
        called as ``reader(cache_path, **read_kwargs)`` to load a result.
    writer : callable, optional
        called as ``writer(result, cache_path, **write_kwargs)`` to store a
        result.
    ext : str, optional
        file extension appended to hashed file names under ``storage_dir``.
    read_kwargs, write_kwargs : dict, optional
        extra keyword arguments for ``reader`` and ``writer``.

    Returns
    -------
    decorated : function
        wrapper around ``func`` with file-based caching

    Raises
    ------
    ValueError
        at call time, if no ``cache_path`` is passed and neither
        ``storage_dir`` nor ``storage_pattern`` was set at decoration.
    """
    read_kwargs = read_kwargs if read_kwargs is not None else {}
    write_kwargs = write_kwargs if write_kwargs is not None else {}

    sig = inspect.signature(func)

    @functools.wraps(func)
    def inner(*args, cache_path=None, overwrite=False, **kwargs):
        if cache_path is None:
            # Map positional arguments onto parameter names and fill in
            # defaults so equivalent calls resolve to the same cache path.
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            call_args = dict(bound.arguments)

            if storage_pattern is not None:
                cache_path = storage_pattern.format(**call_args)
            elif storage_dir is not None:
                # Sort a **kwargs catch-all so keyword order at the call
                # site does not change the digest.
                for name, param in sig.parameters.items():
                    if param.kind is inspect.Parameter.VAR_KEYWORD:
                        call_args[name] = sorted(call_args[name].items())

                # NOTE: the digest is only as stable as the pickled form of
                # the argument values; prefer simple, explicit arguments.
                digest = hashlib.sha256(
                    pickle.dumps(sorted(call_args.items()))).hexdigest()
                cache_path = os.path.join(storage_dir, digest + ext)

            else:
                raise ValueError(
                    'must provide cache_path or define storage_dir '
                    'or storage_pattern at function decoration')

        if not overwrite:
            try:
                return reader(cache_path, **read_kwargs)
            except (OSError, ValueError):
                # cache miss (missing or unreadable file): recompute below
                pass

        # Call with the original arguments; positionals must not be passed
        # a second time as keywords.
        res = func(*args, **kwargs)

        if create_directories:
            parent = os.path.dirname(cache_path)
            # a bare file name has no parent and os.makedirs('') would fail
            if parent:
                os.makedirs(parent, exist_ok=True)

        writer(res, cache_path, **write_kwargs)
        return res

    return inner

This could be extended with a number of format-specific decorators quite easily

# Format-specific decorators. Each one fixes the reader, writer, and the file
# extension used for hashed file names under ``storage_dir``.

cache_in_netcdf = cache_results(
    reader=xr.open_dataset,
    writer=lambda ds, fp, **kwargs: ds.to_netcdf(fp, **kwargs),
    ext='.nc')

cache_in_zarr = cache_results(
    reader=xr.open_zarr,
    writer=lambda ds, fp, **kwargs: ds.to_zarr(fp, **kwargs),
    ext='.zarr')

cache_in_parquet = cache_results(
    reader=pd.read_parquet,
    writer=lambda df, fp, **kwargs: df.to_parquet(fp, **kwargs),
    ext='.parquet')

# CSV variant referenced by the usage examples; it spells out the defaults of
# ``cache_results`` explicitly.
cache_in_csv = cache_results(
    reader=pd.read_csv,
    writer=lambda df, fp, **kwargs: df.to_csv(fp, **kwargs),
    ext='.csv')

def cache_in_pickle(*args, **kwargs):
    """
    Variant of ``cache_results`` that stores results with ``pickle``

    All arguments are forwarded to ``cache_results``; ``reader`` and
    ``writer`` are supplied here and must not be passed by the caller.
    ``ext`` defaults to ``'.pkl'`` unless given explicitly.
    """
    import pickle

    def reader(fp, **kw):
        # NOTE: unpickling can execute arbitrary code; only read cache files
        # from locations you trust.
        with open(fp, 'rb') as f:
            return pickle.load(f)

    def writer(data, fp, **kw):
        with open(fp, 'wb') as f:
            pickle.dump(data, f)

    # Without this, hashed file names would inherit the '.csv' default.
    kwargs.setdefault('ext', '.pkl')

    return cache_results(*args, reader=reader, writer=writer, **kwargs)

Proposed usage

These could then be used in a variety of ways.

No arguments on decoration requires that a path be provided when called:

@cache_in_csv
def generate_random_df(length):
    """Return a DataFrame with one column of ``length`` random values."""
    return pd.DataFrame({'random': np.random.random(length)})

# no storage location was given at decoration, so cache_path is required here
df = generate_random_df(4, cache_path='my_length_4_df.csv')

Providing a storage pattern allows you to set up a complex directory structure

@cache_in_netcdf(
        storage_pattern='/data/transformed/tasmax_squared/{rcp}/{model}/{year}.nc',
        create_directories=True)
def square_tasmax(rcp, model, year):
    """Open the source tasmax dataset for (rcp, model, year) and square it."""
    source_path = '/data/source/nasa-nex/tasmax/{rcp}/{model}/{year}.nc'.format(
        rcp=rcp, model=model, year=year)
    tasmax = xr.open_dataset(source_path)
    return tasmax ** 2

rcps = ['rcp45', 'rcp85']
models = ['ACCESS1-0', 'IPSL-ESM-CHEM']
years = [2020, 2050, 2090]

results = []
for rcp in rcps:
    for model in models:
        for year in years:
            # every (rcp, model, year) combination is cached in its own file
            key = (rcp, model, year)
            results.append((key, square_tasmax(*key)))

We can also pass reader/writer kwargs for more complex IO:

@cache_in_parquet(
        read_kwargs={'storage_options': {'token': 'cloud'}},
        write_kwargs={'storage_options': {'token': 'cloud'}})
def my_long_pandas_operation():
    """Simulate a slow computation that returns a 6x2 random DataFrame."""
    time.sleep(4)
    values = np.random.random((6, 2))
    return pd.DataFrame(values, columns=['a', 'b'])

# the reader/writer kwargs given at decoration are applied to this remote path
df = my_long_pandas_operation(cache_path='gs://my_project/my_big_file.parquet')

Once the argument hashing in the TODO referenced above is implemented, we could handle arbitrarily complex argument calls, which will be hashed to form a unique, stable file name, e.g.:

@cache_in_pickle(storage_dir='/data/cached_noaa_api_calls')
def call_noaa_api(*args, **kwargs):
    """Forward all arguments to ``noaa_api.query`` and return its result."""
    response = noaa_api.query(*args, **kwargs)
    return response

TODO

  • Look harder to see if something like this already exists!
  • Implement the hashing and arg/kwarg inspection features
  • Make sure the hashing implementation is stable and would not ever return wrong results (e.g. for changing defaults... ugh).
  • Maybe implement staleness redo criteria?

Did some preliminary research on other popular "caching" libraries. This SO (Stack Overflow) question was helpful.

First up, some takeaways

  1. This is not caching in the normal usage. Memcaching is by far the most popular use case, which only persists objects within a session, not across them. For examples see @functools.lru_cache, pycache
  2. Most caching modules rely on pickle or cpickle. While this is useful, it's not guaranteed to be stable over long periods for all types of objects, especially complex data structures like pandas DataFrames.
  3. We may want flexibility in terms of loaders/writers, filepaths, and arguments which are allowed to affect caching behavior. Some of these are supported by some libraries but not all.
  4. It's nice to be able to define these caches in directory structures and locations that make sense outside the context of the caching library, so that we can, e.g., cache tasmax_squared as part of our pipeline but also allow anyone to discover, inspect, and use this object.
  5. Getting this right for simultaneous read/write in the cloud is hard and requires some serious engineering. We should be careful to either use a real library that handles this for us or to avoid these situations (e.g. make sure we're only using pure functions and that inputs fully specify outputs). That said, we don't need to worry about simultaneous writes resulting in corrupted data, as google simply goes with the last-written object, whether over gcsfuse, gsutil, or google.cloud.storage.

Now, the other remotely feasible libraries

On-disk pickled caches

  • shelve - built-in method for pickle-based on-disk "dictionaries". A bit more manual of a solution, but deserves mention.
  • pyfscache - potentially a great alternative for many of our very frequent API calls, e.g. to NOAA, which return little data and take a long time. Relies on cpickle and does not write objects as individual items that we could interpret outside the context of the cache, so probably not suitable for something like netcdf climate data.
  • cachetools
  • joblib.Memory - seems like a great option for pickle-based caching
  • bda.cache - may be wrapping pyfscache? Not really sure. Depends on cpickle for disk caching.

Server-based caching solutions

These would be a radically different approach to computing... but maybe?

oops.

Here's my implementation for caching NOAA API calls

from __future__ import absolute_import

import os
import toolz
import pickle
import inspect
import hashlib
import functools

from os.path import join
from sklearn.gaussian_process.kernels import RBF, _check_length_scale
from scipy.spatial.distance import pdist, squareform, cdist
import numpy as np
import pandas as pd
import shapely as shp
import shapely.geometry
import scipy.interpolate

import pyTC.settings


def get_error_type_indices(ftrs):
    """
    Group the positions of failed futures by the kind of exception raised

    Parameters
    ----------
    ftrs : iterable
        future-like objects exposing a ``status`` attribute and an
        ``exception()`` method (presumably ``dask.distributed`` futures;
        confirm against callers). Only entries whose ``status`` equals
        ``"error"`` are inspected.

    Returns
    -------
    indices : dict
        ``{"io": [...], "fnf": [...], "other": [...]}`` where each list holds
        the positions in ``ftrs`` of futures that failed with, respectively,
        an ``OSError`` other than ``FileNotFoundError``, a
        ``FileNotFoundError``, or any other exception.
    """
    io_indices = []
    fnf_indices = []
    other_indices = []

    # enumerate gives each entry its own position; looking positions up with
    # list.index is quadratic and repeats the first index for duplicates
    for idx, ftr in enumerate(ftrs):
        if ftr.status != "error":
            continue

        exc = ftr.exception()

        # FileNotFoundError subclasses OSError, so it must be tested first
        if isinstance(exc, FileNotFoundError):
            fnf_indices.append(idx)
        elif isinstance(exc, OSError):
            io_indices.append(idx)
        else:
            other_indices.append(idx)

    return {"io": io_indices, "fnf": fnf_indices, "other": other_indices}


@toolz.curry
def cache_result_in_pickle(func, cache_dir=None, makedirs=False, error="raise"):
    """
    Caches the results of a function in the specified directory

    Uses the python pickle module to store the results of a
    function call in a directory, with file names set to the
    sha256 hash of the function's arguments. Pass `redo=True`
    or delete the contents of the directory to reset the cache.

    Because the results are cached based only on function
    parameters, it is important that the function not have any
    side effects.

    Note that all function arguments are hashed to derive a
    cached filename, and that any change to any input will
    produce a new cached file. Therefore, functions that
    depend on complex, frequently changing objects, especially
    settings objects, should not be cached. Instead, cache
    lower-level functions with a small list of simple,
    explicit arguments.

    Note also that cached files are not cleaned up
    automatically, and therefore changes in the arguments to a
    function will result in a new set of cached files being
    saved without removing the older files. This could result
    in cache storage creep unless the cache is periodically
    cleared. Clearing the cache based on file creation date
    can be an important part of cache maintenance.

    Cached files are read with `pickle.load`, which can execute
    arbitrary code. Only point `cache_dir` at locations you trust.

    .. todo::

        replace this function with a more complete
        implementation, e.g. the one described in
        [GH RhodiumGroup/rhg_compute_tools#56](https://github.com/RhodiumGroup/rhg_compute_tools/issues/56).

    Parameters
    ----------
    func : function
        function to decorate. cannot have `redo`, `cache_dir`,
        `makedirs`, or `error` as an argument, because the wrapper
        consumes these keywords itself.
    cache_dir : str
        path to the root directory used in caching. If not
        provided, will use the `DIR_DATA_CACHE` attribute
        from `pyTC.settings.Settings()`, either one passed as `ps`
        (by keyword) to the wrapped func, or the default settings
        object if none is provided.
    makedirs : bool, optional
        if True, create the per-function cache subdirectory (and
        any missing parents) before writing a result. Default False.
    error : str, optional
        what to do when writing the cache file fails with an
        `OSError` or `ValueError`. One of `'raise'` (default;
        re-raise the error), `'ignore'` (return the result without
        caching it), or `'remove'` (try to delete the cache file,
        then raise a `RuntimeError` chained to the original error).

    Returns
    -------
    decorated : function
        Function, with cached results

    Raises
    ------
    ValueError
        at call time, if `error` is not one of the accepted values.

    Examples
    --------

    .. code-block:: python

        >>> @cache_result_in_pickle(cache_dir=(tmpdir + '/cache'), makedirs=True)
        ... def long_running_func(i):
        ...     import time
        ...     time.sleep(0.1)
        ...     return i
        ...

    Initial calls will execute the function fully

    .. code-block:: python

        >>> long_running_func(1)  # > 0.1s
        1

    Subsequent calls will be much faster

    .. code-block:: python

        >>> long_running_func(1)  # << 0.1 s
        1

    Changing the arguments will result in re-evaluation

    .. code-block:: python

        >>> long_running_func(3)  # > 0.1s
        3

    Cached results are stored in the specified directory, under a
    subdirectory for each decorated function:

    .. code-block:: python

        >>> os.listdir(
        ...     tmpdir + '/cache/pyTC.utilities.long_running_func'
        ... )  # doctest: +NORMALIZE_WHITESPACE
        ...
        ['259ca9884c55ef7e909c0558978d73f915c6454d8e38bc576e8d48179138491a',
         '57630b792604ad1c663441890cda34728ffcb2c04d6b29dc720fd810318b61b6']

    Deleting these files would reset the cache without error. The cache can
    also be refreshed on a per-call basis by passing `redo=True` to the
    function call:

    .. code-block:: python

        >>> long_running_func(1, redo=True)  # > 0.1s
        1

    The parameters `'cache_dir'`, `'makedirs'`, and `'error'` can also be
    overridden at function call:

    .. code-block:: python

        >>> long_running_func(1, cache_dir=(tmpdir + '/cache2'))
        1
        >>> os.listdir(
        ...     tmpdir + '/cache2/pyTC.utilities.long_running_func'
        ... )  # doctest: +NORMALIZE_WHITESPACE
        ...
        ['259ca9884c55ef7e909c0558978d73f915c6454d8e38bc576e8d48179138491a']

    """

    funcname = ".".join([func.__module__, func.__name__])
    sig = inspect.Signature.from_callable(func)

    # decoration-time values, used when a call does not override them
    default_cache_dir = cache_dir
    default_makedirs = makedirs
    default_error = error

    @functools.wraps(func)
    def inner(*args, redo=False, cache_dir=None, makedirs=None, error=None, **kwargs):

        if cache_dir is None:
            cache_dir = default_cache_dir

        if makedirs is None:
            makedirs = default_makedirs

        if error is None:
            error = default_error

        if error is None:
            error = "raise"

        # validate explicitly: an assert would be stripped under `python -O`
        error = str(error).lower()
        if error not in ("raise", "ignore", "remove"):
            raise ValueError(
                "error must be one of `'raise'`, `'ignore'`, or `'remove'`"
            )

        if cache_dir is None:
            # fall back to the settings object, preferring one passed to func
            ps = kwargs.get("ps")

            if ps is None:
                ps = pyTC.settings.Settings()

            cache_dir = ps.DIR_DATA_CACHE

        # bind and apply defaults so that equivalent calls hash identically
        bound_args = sig.bind(*args, **kwargs)
        bound_args.apply_defaults()

        sha = hashlib.sha256(pickle.dumps(bound_args))
        path = os.path.join(cache_dir, funcname, sha.hexdigest())

        if not redo:
            try:
                with open(path, "rb") as f:
                    return pickle.load(f)
            except (OSError, EOFError, pickle.UnpicklingError):
                # missing, unreadable, or truncated/corrupt cache file:
                # treat as a cache miss and recompute
                pass

        res = func(*args, **kwargs)

        try:
            if makedirs:
                os.makedirs(os.path.dirname(path), exist_ok=True)

            with open(path, "wb+") as f:
                pickle.dump(res, f)

        except (OSError, ValueError) as e:
            if error == "raise":
                raise
            elif error == "remove":
                # best-effort cleanup of a possibly partial file
                try:
                    os.remove(path)
                except OSError:
                    pass
                raise RuntimeError from e
            else:
                # case error == 'ignore'
                pass

        return res

    return inner