ets-labs/python-dependency-injector

Is there a way to create List[Provider[X]] from Provider[List[X]]?

Hi, thanks for providing this library, I really like it!

Could you tell me if there is a way to create List[Provider[X]] from Provider[List[X]]?

I call this operation InsideOut.

def InsideOut(provider: Provider[List[X]]) -> List[Provider[X]]: ...


import torch.utils.data
import torchvision.datasets
from torchvision import transforms
from dependency_injector import containers, providers

def input_transform():
    return transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5,), (0.5,))])

class AppContainer(containers.DeclarativeContainer):

    transform = providers.Singleton(input_transform)
    base_set = providers.Singleton(
        torchvision.datasets.FashionMNIST,
        './data', train=True,
        transform=transform, download=True)

    # providers.InsideOut does not exist; this is the API I would like to have.
    train, validate = providers.InsideOut(
        providers.Singleton(torch.utils.data.random_split, base_set, [0.8, 0.2]))

I've experimented with providers.Aggregate, but that gives a Provider[Dict], which is not what I want. It was also very difficult (maybe even impossible) to get pyright to stop reporting errors. The cleanest solution I have at the moment is to keep the list provider and index into it, i.e.:

split = providers.Singleton(torch.utils.data.random_split, base_set, [0.8, 0.2])  # Provider[List]

train_loader = providers.Singleton(
    torch.utils.data.DataLoader,
    split.provided[0], batch_size=4, shuffle=True)

This works but I feel the first way would be cleaner.

Thanks,