weiji14/zen3geo

function docstrings for torchdata methods do not show in notebook with ??

Closed · 4 comments

Describe the bug
When I run the Jupyter help magic (??) on a zen3geo method, I get the docstring for the generic functools.partial class instead of the method's own docstring:

dp.search_for_pystac_item??
Signature:   dp.search_for_pystac_item(*args, **kwargs)
Type:        partial
String form: functools.partial(<function IterDataPipe.register_datapipe_as_function.<locals>.class_function at 0x14b820f70>, <class 'zen3geo.datapipes.pystac_client.PySTACAPISearcherIterDataPipe'>, False, IterableWrapperIterDataPipe)
File:        ~/.pyenv/versions/3.10.6/lib/python3.10/functools.py
Source:     
class partial:
    """New function with partial application of the given arguments
    and keywords.
    """

    __slots__ = "func", "args", "keywords", "__dict__", "__weakref__"

    def __new__(cls, func, /, *args, **keywords):
        if not callable(func):
            raise TypeError("the first argument must be callable")

I expect to see the docstring for the zen3geo method so I know what args to supply (like how to pass auth credentials to use a STAC API). The actual docstring is here: https://zen3geo.readthedocs.io/en/latest/_modules/zen3geo/datapipes/pystac_client.html?highlight=search_for_pystac_item#

Expected behavior
Is there a way to register the correct docstring on a torchdata method, instead of the docstring for partial?

Yeah, this is one downside of using torchdata's functional form, and it's actually the same for all the official torch DataPipes upstream in https://github.com/pytorch/data, which use the @functional_datapipe decorator, see https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/data/datapipes/_decorator.py#L11-L38. E.g. if you do dp.map??, it will also show the docstring of partial.
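To illustrate why this happens, here is a minimal sketch in plain Python. It does not use torchdata at all: Searcher and class_function are hypothetical stand-ins that only mimic the shape of the partial shown in the String form output above, so treat it as an approximation of the mechanism rather than the actual torchdata code.

```python
import functools


class Searcher:
    """Docstring of the class-based DataPipe stand-in."""

    def __init__(self, source, catalog_url):
        self.source = source
        self.catalog_url = catalog_url


def class_function(cls, source, *args, **kwargs):
    # Build the class with the source pipe injected as the first argument
    return cls(source, *args, **kwargs)


# Stand-in for the functional form: a partial bound to the class and source
method = functools.partial(class_function, Searcher, ["abc", "def"])

# The partial reports functools.partial's own docstring, not Searcher's
print(method.__doc__ == functools.partial.__doc__)
print(method.__doc__ == Searcher.__doc__)
```

Calling method(catalog_url=...) still builds a Searcher as expected; only the introspection metadata is lost, which is why help() and ?? show the partial class.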

The class-based form is documented though, e.g. help(zen3geo.datapipes.PySTACAPISearcher), but you'll need to know how the functional form and class form map to each other, which requires tab-completing from zen3geo.datapipes, or looking at the online API docs 🙂
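As a possible workaround for finding the class behind a functional name, the String form output above suggests the DataPipe class is stored among the partial's bound arguments. The sketch below relies on that layout, which is an internal detail and an assumption on my part; find_datapipe_class is a hypothetical helper, and Searcher and class_function are stand-ins so the snippet runs without torchdata.

```python
import functools
import inspect


def find_datapipe_class(method):
    """Return the first class bound in a (possibly nested) partial, else None."""
    while isinstance(method, functools.partial):
        for arg in method.args:
            if inspect.isclass(arg):
                return arg
        # Nothing found at this level, so look at the wrapped callable
        method = method.func
    return None


class Searcher:
    """Docstring of the class-based DataPipe stand-in."""


def class_function(cls, enable_tracing, source, *args, **kwargs):
    return cls()


# Same argument layout as in the String form output: (class, False, source)
method = functools.partial(class_function, Searcher, False, "source")
print(find_datapipe_class(method).__doc__)
```

If the layout holds on a real DataPipe, help(find_datapipe_class(dp.search_for_pystac_item)) would be the equivalent call, though I have not verified this across torchdata versions.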

I'm aware of https://docs.python.org/3/library/functools.html#functools.wraps (see also https://stackoverflow.com/questions/308999/what-does-functools-wraps-do), which can be used as a decorator to 'copy' the documentation from a wrapped function to a wrapper function, but I'm not sure if it works when the wrapped object is a Python class (which is how the DataPipes are written).
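For what it's worth, functools.wraps is a thin decorator around functools.update_wrapper, and update_wrapper can be called directly on a partial object with a class as the wrapped argument. The sketch below is a plain-Python experiment under that assumption, not torchdata's implementation; Searcher and class_function are hypothetical stand-ins.

```python
import functools


class Searcher:
    """Takes query dictionaries and yields search results."""

    def __init__(self, source, catalog_url=None):
        self.source = source
        self.catalog_url = catalog_url


def class_function(cls, source, *args, **kwargs):
    return cls(source, *args, **kwargs)


method = functools.partial(class_function, Searcher)

# Copy only __doc__ from the class; updated=() avoids merging the class
# __dict__ into the partial's instance __dict__
functools.update_wrapper(method, Searcher, assigned=("__doc__",), updated=())

print(method.__doc__)
# prints "Takes query dictionaries and yields search results."
```

Copying only __doc__ keeps the partial's other attributes untouched. update_wrapper also sets method.__wrapped__ to Searcher, which some introspection tools follow when resolving signatures.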

FYI, I've reported this upstream to torchdata at pytorch/data#792, let's see what the response is.

Thanks for the explanation and posting the other issue.

I've fixed the bug upstream as mentioned at pytorch/data#792 (comment), and using the PyTorch nightly builds (e.g. with pip install --pre torch torchdata --index-url https://download.pytorch.org/whl/nightly/cu121) should show the documentation properly. On torch=2.1.0.dev20230519+cu121 and torchdata=0.7.0.dev20230519, I've confirmed that help(dp.search_for_pystac_item) shows a better docstring when running this code:

import torchdata
import zen3geo

dp = torchdata.datapipes.iter.IterableWrapper(iterable=["abc", "def"])
help(dp.search_for_pystac_item)

The output should look something like this:

Help on partial in module zen3geo.datapipes.pystac_client:

functools.partial(functools.partial(<function It...rDataPipe'>, False), IterableWrapperIterDataPipe)
    Takes dictionaries containing a STAC API query (as long as the parameters
    are understood by :py:meth:`pystac_client.Client.search`) and yields
    :py:class:`pystac_client.ItemSearch` objects (functional name:
    ``search_for_pystac_item``).
    
    Parameters
    ----------
    source_datapipe : IterDataPipe[dict]
        A DataPipe that contains STAC API query parameters in the form of a
        Python dictionary to pass to :py:meth:`pystac_client.Client.search`.
        For example:
    
        - **bbox** -  A list, tuple, or iterator representing a bounding box of
          2D or 3D coordinates. Results will be filtered to only those
          intersecting the bounding box.
        - **datetime** - Either a single datetime or datetime range used to
          filter results. You may express a single datetime using a
          :py:class:`datetime.datetime` instance, a
          `RFC 3339-compliant <https://tools.ietf.org/html/rfc3339>`_
          timestamp, or a simple date string.
        - **collections** - List of one or more Collection IDs or
          :py:class:`pystac.Collection` instances. Only Items in one of the
          provided Collections will be searched.
    
    catalog_url : str
        The URL of a STAC Catalog.
    
    kwargs : Optional
        Extra keyword arguments to pass to
        :py:meth:`pystac_client.Client.open`. For example:
    
        - **headers** - A dictionary of additional headers to use in all
          requests made to any part of this Catalog/API.
        - **parameters** - Optional dictionary of query string parameters to
          include in all requests.
        - **modifier** - A callable that modifies the children collection and
          items returned by this Client. This can be useful for injecting
          authentication parameters into child assets to access data from
          non-public sources.
    
    Yields
    ------
    item_search : pystac_client.ItemSearch
        A :py:class:`pystac_client.ItemSearch` object instance that represents
        a deferred query to a STAC search endpoint as described in the
        `STAC API - Item Search spec <https://github.com/radiantearth/stac-api-spec/tree/main/item-search>`_.
    
    Raises
    ------
    ModuleNotFoundError
        If ``pystac_client`` is not installed. See
        :doc:`install instructions for pystac-client <pystac_client:index>`,
        (e.g. via ``pip install pystac-client``) before using this class.
    
    Example
    -------
    >>> import pytest
    >>> pystac_client = pytest.importorskip("pystac_client")
    ...
    >>> from torchdata.datapipes.iter import IterableWrapper
    >>> from zen3geo.datapipes import PySTACAPISearcher
    ...
    >>> # Peform STAC API query using DataPipe
    >>> query = dict(
    ...     bbox=[174.5, -41.37, 174.9, -41.19],
    ...     datetime=["2012-02-20T00:00:00Z", "2022-12-22T00:00:00Z"],
    ...     collections=["cop-dem-glo-30"],
    ... )
    >>> dp = IterableWrapper(iterable=[query])
    >>> dp_pystac_client = dp.search_for_pystac_item(
    ...     catalog_url="https://planetarycomputer.microsoft.com/api/stac/v1",
    ...     # modifier=planetary_computer.sign_inplace,
    ... )
    >>> # Loop or iterate over the DataPipe stream
    >>> it = iter(dp_pystac_client)
    >>> stac_item_search = next(it)
    >>> stac_items = list(stac_item_search.items())
    >>> stac_items
    [<Item id=Copernicus_DSM_COG_10_S42_00_E174_00_DEM>]
    >>> stac_items[0].properties  # doctest: +NORMALIZE_WHITESPACE
    {'gsd': 30,
     'datetime': '2021-04-22T00:00:00Z',
     'platform': 'TanDEM-X',
     'proj:epsg': 4326,
     'proj:shape': [3600, 3600],
     'proj:transform': [0.0002777777777777778,
      0.0,
      173.9998611111111,
      0.0,
      -0.0002777777777777778,
      -40.99986111111111]}

Closing as done 😎