Tools to generate and use multi-object datasets. The datasets consist of images and a dictionary of labels, where each image is labeled with 1) the number of objects in it and 2) each object's attributes.
Using the datasets only requires `numpy`, as datasets are stored as `.npz` files. Generating sprites requires `scikit-image`. Tools for using the datasets in PyTorch are provided, with usage examples.
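The sketch below illustrates the storage format described above: a stack of images plus a dictionary of labels, saved together in one `.npz` file and loaded back with `numpy` alone. The key names `'x'` and `'labels'` and the label fields `'n_obj'` and `'color'` are hypothetical, chosen for illustration; to see which keys a file in `generated/` actually contains, inspect `data.files` after loading it.

```python
import io

import numpy as np

# Hypothetical toy dataset: 4 binary 8x8 images with 0-2 objects each.
# Key names ('x', 'labels') and label fields ('n_obj', 'color') are assumptions.
rng = np.random.default_rng(0)
images = (rng.random((4, 8, 8)) > 0.8).astype(np.uint8)
labels = {
    'n_obj': np.array([0, 1, 2, 1]),   # number of objects in each image
    'color': [[], [3], [0, 5], [2]],   # one attribute value per object
}

# Save images and the label dictionary to an in-memory .npz file.
buf = io.BytesIO()
np.savez_compressed(buf, x=images, labels=labels)
buf.seek(0)

# A dict is stored as a 0-d object array, so loading it needs allow_pickle=True,
# and .item() turns it back into a regular Python dictionary.
with np.load(buf, allow_pickle=True) as data:
    print(sorted(data.files))          # ['labels', 'x']
    x = data['x']
    loaded_labels = data['labels'].item()

print(x.shape, loaded_labels['n_obj'])
```

Because the label dictionary is pickled, only load `.npz` files from sources you trust.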
- Either download one of the datasets in `generated/`, or generate a new one.
- Place the `.npz` dataset in `/path/to/data/`.
- Install the package: `pip install multiobject`
- Usage in PyTorch:
from multiobject.pytorch import MultiObjectDataLoader, MultiObjectDataset
dataset_path = '/path/to/data/some_dataset.npz'
batch_size = 64          # example value, choose to suit your model
test_batch_size = 1000   # example value
train_set = MultiObjectDataset(dataset_path, train=True)
test_set = MultiObjectDataset(dataset_path, train=False)
train_loader = MultiObjectDataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = MultiObjectDataLoader(test_set, batch_size=test_batch_size)
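To run the examples, set up a conda environment:
conda create --name multiobject python=3.7
conda activate multiobject
pip install -r requirements.txt
Then run the VAE example:
CUDA_VISIBLE_DEVICES=0 python demo_vae.py
or the counting example:
CUDA_VISIBLE_DEVICES=0 python demo_count.py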
Datasets are available as `.npz` files in `./generated/`.
dSprites¹
Binary RGB images with monochromatic dSprites on a black canvas. Sprites can overlap (sum and clip).
n images | image size (px) | max object size (px) | objects per image | sprite colors | file size |
---|---|---|---|---|---|
100k | 64x64 | 18x18 | 1 | 7 | 10.6 MB |
100k | 64x64 | 28x28 | 1 | 7 | 12.4 MB |
100k | 64x64 | 18x18 | 0–2 (uniformly) | 7 | 11.0 MB |
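The "sum and clip" overlap rule can be sketched in a few lines of `numpy`: each monochromatic sprite is multiplied by its color, added onto a black RGB canvas, and the result is clipped to [0, 1]. This is an illustration, not the repository's generator; `place_sprites` and `COLORS` are hypothetical names, and reading the 7 sprite colors as the 7 non-black binary RGB triples is an assumption that merely fits "binary RGB images with monochromatic sprites".

```python
import numpy as np

# Assumption: the 7 sprite colors are the non-black binary RGB triples.
COLORS = [(r, g, b) for r in (0, 1) for g in (0, 1) for b in (0, 1)
          if (r, g, b) != (0, 0, 0)]


def place_sprites(sprites, colors, canvas_size=64, rng=None):
    """Composite binary sprites onto a black RGB canvas (hypothetical helper).

    Overlapping sprites are summed and the result is clipped to [0, 1].
    """
    rng = np.random.default_rng() if rng is None else rng
    canvas = np.zeros((canvas_size, canvas_size, 3), dtype=np.float32)
    for sprite, color in zip(sprites, colors):
        h, w = sprite.shape
        # Random top-left corner such that the whole sprite fits on the canvas.
        top = rng.integers(0, canvas_size - h + 1)
        left = rng.integers(0, canvas_size - w + 1)
        colored = sprite[:, :, None] * np.asarray(color, dtype=np.float32)
        canvas[top:top + h, left:left + w] += colored   # sum ...
    return np.clip(canvas, 0.0, 1.0)                     # ... and clip


rng = np.random.default_rng(0)
square = np.ones((18, 18), dtype=np.float32)   # stand-in for an 18x18 dSprite
n_obj = int(rng.integers(0, 3))                # 0-2 objects, uniformly
colors = [COLORS[i] for i in rng.integers(0, len(COLORS), size=n_obj)]
image = place_sprites([square] * n_obj, colors, rng=rng)
```

Clipping after summing keeps every pixel binary per channel, so two overlapping sprites of different colors produce the channel-wise union of their colors in the overlap region.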
Binary MNIST
Binary 64x64 single-channel images with MNIST digits on a black canvas. Digits are rescaled to 18x18 and binarized, and they can overlap (sum and clip). Only digits from the MNIST training set (60k digits) are used.
n images | image size (px) | max object size (px) | objects per image | file size |
---|---|---|---|---|
100k | 64x64 | 18x18 | 1 | 4.5 MB |
100k | 64x64 | 18x18 | 0–2 (uniformly) | 4.8 MB |
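The preprocessing of each digit (rescale the 28x28 MNIST image to 18x18, then binarize) can be sketched with `numpy` as follows. The interpolation method and threshold are not stated above, so nearest-neighbour sampling and a 0.5 threshold are assumptions; the repository may instead use `scikit-image` resizing or another cut-off. `rescale_and_binarize` is a hypothetical name.

```python
import numpy as np


def rescale_and_binarize(digit, out_size=18, threshold=0.5):
    """Downscale a grayscale digit with values in [0, 1] and binarize it.

    Assumptions: nearest-neighbour sampling and a fixed 0.5 threshold.
    """
    in_h, in_w = digit.shape
    # Source pixel closest to the centre of each output pixel.
    rows = ((np.arange(out_size) + 0.5) * in_h / out_size).astype(int)
    cols = ((np.arange(out_size) + 0.5) * in_w / out_size).astype(int)
    small = digit[np.ix_(rows, cols)]
    return (small > threshold).astype(np.uint8)


# Toy "digit": a vertical stroke on a 28x28 canvas.
digit = np.zeros((28, 28))
digit[4:24, 12:16] = 1.0
small = rescale_and_binarize(digit)
print(small.shape, small.sum())
```

The resulting 18x18 binary digits can then be placed on the 64x64 canvas with the same sum-and-clip rule used for dSprites, restricted to a single channel.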
To generate a new dataset:
- Clone this repo.
- See the requirements, or set up a virtual environment as follows:
  conda create --name multiobject python=3.7
  conda activate multiobject
  pip install -r requirements.txt
- Optional: generate a new type of sprites:
  - create a file `sprites/xyz.py` containing a function `generate_xyz()`, where "xyz" denotes the new sprite type;
  - in `generate_dataset.py`, add a call to `generate_xyz()` to generate the correct sprites, and add `'xyz'` to the list of supported sprites.
- Call `generate_dataset.py` with the desired sprite type as the `--type` argument. Example: `python generate_dataset.py --type dsprites`
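The steps for adding a sprite type amount to registering a new generator and exposing it through `--type`. The sketch below shows one way to wire that up with `argparse` and a dictionary registry; this is a swapped-in pattern for illustration, not necessarily how `generate_dataset.py` dispatches internally. The body of `generate_xyz()`, its return format, and the names `SPRITE_GENERATORS` and `main` are hypothetical.

```python
import argparse

import numpy as np


def generate_xyz():
    """Hypothetical new sprite type: two 5x5 binary shapes (square and plus)."""
    square = np.ones((5, 5), dtype=np.uint8)
    plus = np.zeros((5, 5), dtype=np.uint8)
    plus[2, :] = 1
    plus[:, 2] = 1
    return np.stack([square, plus])


# Registry of supported sprite types: adding 'xyz' here is what makes
# `--type xyz` a valid choice on the command line.
SPRITE_GENERATORS = {
    'xyz': generate_xyz,
}


def main(argv=None):
    parser = argparse.ArgumentParser(description='Generate a multi-object dataset.')
    parser.add_argument('--type', required=True, choices=sorted(SPRITE_GENERATORS),
                        help='sprite type to generate')
    args = parser.parse_args(argv)
    sprites = SPRITE_GENERATORS[args.type]()
    print(f'{args.type}: generated {len(sprites)} sprites of shape {sprites.shape[1:]}')
    return sprites


# Equivalent to running: python generate_dataset.py --type xyz
sprites = main(['--type', 'xyz'])
```

With `choices`, an unsupported type is rejected by `argparse` with an error message before any generation starts.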
Sprite attributes are handled automatically when a dataset is generated from a set of sprites that have per-sprite labels. However, since these attributes are dataset-specific, they have to be defined when the sprites are created.
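One way to read this: the sprite generator returns a label array per attribute, with one entry per sprite, and the dataset generator copies the entries of whichever sprites it places into the per-object labels, without knowing the attribute names. The minimal sketch below assumes that design; `generate_xyz`, `object_labels`, and the attributes `'shape'` and `'n_pixels'` are hypothetical.

```python
import numpy as np


def generate_xyz():
    """Hypothetical sprite generator returning sprites and per-sprite labels."""
    sprites = [
        np.ones((4, 4), dtype=np.uint8),   # square
        np.eye(4, dtype=np.uint8),         # diagonal line
        np.tri(4, dtype=np.uint8),         # lower triangle
    ]
    # Dataset-specific attributes, defined when the sprites are created.
    sprite_labels = {
        'shape': np.array([0, 1, 2]),
        'n_pixels': np.array([int(s.sum()) for s in sprites]),
    }
    return sprites, sprite_labels


def object_labels(sprite_labels, sprite_ids):
    """Look up the labels of the sprites placed in one image.

    Attribute names are not hard-coded, so this works for any sprite type
    as long as every attribute array has one entry per sprite.
    """
    return {name: values[sprite_ids] for name, values in sprite_labels.items()}


sprites, sprite_labels = generate_xyz()
rng = np.random.default_rng(0)
ids = rng.integers(0, len(sprites), size=2)   # sprites chosen for a 2-object image
labels = object_labels(sprite_labels, ids)
print(ids, labels)
```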
Note: for now, the following parameters have to be customized directly in `generate_dataset.py`:
- probability distribution over number of objects
- image size
- sprite size
- dataset size
- whether sprites can overlap
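Since these parameters are currently hard-coded, one option is to collect them in a single configuration object. The sketch below is hypothetical (`GenerationConfig` and its field names are not part of the repository); its defaults are taken from the dSprites row with 0–2 objects per image, and it shows how a probability distribution over the number of objects can be sampled with `numpy`. The `allow_overlap` flag is only stored here: disallowing overlap would additionally require rejecting placements that intersect, which this sketch does not implement.

```python
from dataclasses import dataclass, field

import numpy as np


@dataclass
class GenerationConfig:
    """Hypothetical grouping of the parameters listed above."""
    n_images: int = 100_000     # dataset size
    image_size: int = 64        # canvas is image_size x image_size pixels
    sprite_size: int = 18       # maximum object size in pixels
    allow_overlap: bool = True  # overlapping sprites are summed and clipped
    # Probability distribution over the number of objects per image.
    n_obj_probs: dict = field(default_factory=lambda: {0: 1 / 3, 1: 1 / 3, 2: 1 / 3})

    def sample_n_objects(self, rng, size):
        """Draw the number of objects for `size` images."""
        counts = np.array(list(self.n_obj_probs.keys()))
        probs = np.array(list(self.n_obj_probs.values()), dtype=float)
        return rng.choice(counts, size=size, p=probs / probs.sum())


cfg = GenerationConfig()
n_objects = cfg.sample_n_objects(np.random.default_rng(0), size=1000)
print(np.bincount(n_objects))
```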
Requirements to generate datasets:
- numpy==1.18.1
- matplotlib==3.1.2
- scikit_image==0.16.2
- tqdm==4.41.1
- pillow==7.0.0
Requirements to run the examples or use the PyTorch tools:
- torch==1.4.0
- torchvision==0.5.0
¹ This is actually an extension of the original dSprites dataset to multiple objects and to color images.