
LLAMASS is an arbitrary collection of tools I've put together to deal with motion data

Primary language: Jupyter Notebook. License: Apache License 2.0 (Apache-2.0).

llamass

LLAMASS is a Loader for the AMASS dataset.

Motivation

I'm writing this for a project that trains models on pose data from the AMASS dataset. I wanted to be able to install it easily in Colab notebooks and elsewhere. Hopefully it's also useful to other people, but be aware that this is research code and is not necessarily reliable.


Install

License Agreement

Before using the AMASS dataset, you're expected to sign up and accept the license agreement on the AMASS website. This package doesn't require any other code from MPI, but visualizing pose data does (see For Visualization below).

Install with pip

Requirements are handled by pip during installation, but in a new environment I would install PyTorch first so that CUDA is configured correctly for the system.

pip install llamass

For Visualization

There are plotting demos in the AMASS repo and in the smplx repo. Based on these, I wrote a library for plotting without having to think about the betas, DMPLs, etc. It's called gaitplotlib and it can be installed from GitHub:

pip install git+https://github.com/gngdb/gaitplotlib.git

How to use

Downloading the data

The AMASS website provides links to download the various parts of the AMASS dataset. Each part is provided as a .tar.bz2 archive, and I had to download them from the website by hand. Save all of them in a single folder.

Unpacking the data

Once installed, llamass provides a console script, fast_amass_unpack, to unpack the .tar.bz2 files downloaded from the AMASS website:

fast_amass_unpack -n 4 --verify <dir with .tar.bz2 files> <dir to save unpacked data>

This unpacks the data in parallel using 4 jobs and shows a progress bar. The --verify flag computes an md5sum of each directory the files are unpacked to and checks it against the hashes I recorded when I unpacked the data myself. It also skips tar files that have already been unpacked by looking for saved .hash files in the target directory. Verification is slower but more reliable, and it recovers from incomplete unpacking.
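For intuition about what a directory checksum and a skip marker involve, here is a minimal sketch using Python's hashlib. It is not llamass's code: the file ordering, the function names directory_md5 and already_unpacked, and the "<archive name>.hash" marker naming are all assumptions, so digests from this sketch will not match the ones llamass checks against.

```python
import hashlib
from pathlib import Path


def directory_md5(directory):
    """Hash every file under `directory` into one md5 hex digest.

    Files are visited in sorted relative-path order so the digest does not
    depend on filesystem listing order. Illustrative scheme only (assumption).
    """
    root = Path(directory)
    h = hashlib.md5()
    for path in sorted(p for p in root.rglob("*") if p.is_file()):
        # include the relative path so a renamed file changes the digest
        h.update(path.relative_to(root).as_posix().encode())
        with open(path, "rb") as f:
            # read in 1 MiB chunks to avoid loading large archives into memory
            for chunk in iter(lambda: f.read(1 << 20), b""):
                h.update(chunk)
    return h.hexdigest()


def already_unpacked(tar_path, target_dir):
    """Hypothetical skip check: does <archive name>.hash exist in target_dir?"""
    marker = Path(target_dir) / (Path(tar_path).name + ".hash")
    return marker.exists()
```

Hashing in sorted order matters: without it, two identical directories could produce different digests on different filesystems.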

The same unpacking can also be done from Python with the llamass.core.unpack_body_models function:

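import tempfile
import llamass.core

# directory to unpack into (a temporary directory in this example)
unpacked_directory = tempfile.mkdtemp()
# unpack every .tar.bz2 found in sample_data/ using 4 parallel jobs
llamass.core.unpack_body_models("sample_data/", unpacked_directory, 4)
sample_data/sample.tar.bz2 extracting to /tmp/tmp06iwsfhu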
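To see roughly what parallel unpacking involves, the sketch below extracts every .tar.bz2 in a folder with the standard-library tarfile module and a thread pool. It is a hypothetical stand-in (extract_one and unpack_all are illustrative names), not how unpack_body_models is implemented, and it has no progress bar or verification. A thread pool keeps the sketch portable; a process pool is an alternative if decompression turns out to be the bottleneck.

```python
import tarfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path


def extract_one(tar_path, out_dir):
    """Extract a single .tar.bz2 archive into out_dir and return its path."""
    # use the safer "data" extraction filter where this Python version has it
    kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
    with tarfile.open(tar_path, "r:bz2") as tar:
        tar.extractall(out_dir, **kwargs)
    return str(tar_path)


def unpack_all(tar_dir, out_dir, n_jobs=4):
    """Extract every .tar.bz2 in tar_dir into out_dir using n_jobs threads."""
    tar_paths = sorted(Path(tar_dir).glob("*.tar.bz2"))
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    with ThreadPoolExecutor(max_workers=n_jobs) as pool:
        # runs extract_one(tar_path, out_dir) for each archive
        return list(pool.map(extract_one, tar_paths, [out_dir] * len(tar_paths)))
```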

Download Metadata

I've processed the files to find out how many frames are in each numpy archive unpacked by fast_amass_unpack. By default, the first time the AMASS Dataset object is asked for its length, it looks for a file containing this information in the specified AMASS directory. If it doesn't find one, it recomputes the frame counts, which can take around 5 minutes.

Save 5 minutes by downloading it from this repository:

wget https://github.com/gngdb/llamass/raw/master/npz_file_lens.json.gz -P <dir to save unpacked data>
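If you'd rather recompute the frame counts yourself, the sketch below walks the unpacked directory, reads the first dimension of each archive's poses array, and stores the result as gzipped JSON. The layout here (relative path mapped to frame count) is an assumption; npz_file_lens.json.gz may be keyed differently, and count_frames, save_lens and load_lens are hypothetical helpers rather than llamass functions.

```python
import gzip
import json
from pathlib import Path

import numpy as np


def count_frames(amass_dir):
    """Map each .npz under amass_dir (relative POSIX path) to its frame count.

    Assumes the frame count is the first dimension of the 'poses' array and
    skips archives without one (e.g. shape-only files).
    """
    root = Path(amass_dir)
    lens = {}
    for path in sorted(root.rglob("*.npz")):
        with np.load(path) as archive:
            if "poses" in archive.files:
                lens[path.relative_to(root).as_posix()] = int(archive["poses"].shape[0])
    return lens


def save_lens(lens, out_path):
    # text-mode gzip so json can write str directly
    with gzip.open(out_path, "wt") as f:
        json.dump(lens, f)


def load_lens(path):
    with gzip.open(path, "rt") as f:
        return json.load(f)
```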

Train/val/test Split

Details of the script for splits go here.

Using the data

Once the data is unpacked, it can be loaded directly by a PyTorch DataLoader using the llamass.core.AMASS Dataset class, which takes the following options:

  • overlapping: whether the clips of frames taken from each file are allowed to overlap
  • clip_length: how many frames long each clip taken from a file should be
  • transform: a transformation function to apply to all fields
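To make overlapping and clip_length concrete, here is one plausible way to choose clip start frames within a single file. The stride of 1 for overlapping clips and dropping a leftover tail shorter than clip_length are assumptions, and clip_starts is a hypothetical helper, not part of llamass.

```python
def clip_starts(n_frames, clip_length, overlapping):
    """Start frames of clip_length-frame clips from a file with n_frames frames.

    overlapping=True:  every valid start (stride 1), so clips share frames
    overlapping=False: back-to-back clips (stride clip_length)
    A trailing remainder shorter than clip_length is dropped (assumption).
    """
    if clip_length < 1:
        raise ValueError("clip_length must be >= 1")
    stride = 1 if overlapping else clip_length
    return list(range(0, n_frames - clip_length + 1, stride))
```

Under these assumptions, a 10-frame file with clip_length=4 yields starts [0, 4] without overlap and [0, 1, ..., 6] with overlap; each clip would then be poses[s:s + clip_length].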

Because AMASS is an IterableDataset, it cannot be shuffled by the DataLoader: passing shuffle=True to the DataLoader raises an error. However, the AMASS class implements shuffling itself, which can be enabled by passing shuffle=True at initialisation.
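The README doesn't say how AMASS shuffles internally. A common technique for streaming datasets is a shuffle buffer, sketched here with a seeded random.Random so the order is reproducible. Treat it as a swapped-in illustration (buffered_shuffle is a hypothetical name), not llamass's method.

```python
import random


def buffered_shuffle(iterable, buffer_size, seed=0):
    """Yield items of iterable in an approximately shuffled, reproducible order.

    Keeps up to buffer_size items in memory; once the buffer is full, each new
    item replaces a randomly chosen buffered item, which is yielded.
    """
    if buffer_size < 1:
        raise ValueError("buffer_size must be >= 1")
    rng = random.Random(seed)
    buffer = []
    for item in iterable:
        if len(buffer) < buffer_size:
            buffer.append(item)
            continue
        i = rng.randrange(buffer_size)
        yield buffer[i]
        buffer[i] = item
    # drain the remaining items in random order
    rng.shuffle(buffer)
    yield from buffer
```

A larger buffer gives an order closer to a full shuffle, at the cost of holding more clips in memory.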

To use more than one worker, you need to pass the provided worker_init_fn to the DataLoader. Alternatively, use llamass.core.IterableLoader in place of DataLoader, which handles this for you. That is the safer option, because a DataLoader without worker_init_fn yields duplicate data when workers load from the same files.
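The duplication happens because, with an IterableDataset, every DataLoader worker gets its own copy of the dataset and iterates it from the start unless the work is split between them. A minimal sketch of one way to split it, assuming round-robin assignment of files (shard_for_worker is hypothetical; llamass's worker_init_fn may partition differently):

```python
def shard_for_worker(files, worker_id, num_workers):
    """Return the round-robin share of files that one DataLoader worker reads.

    In an IterableDataset's __iter__, worker_id and num_workers would come from
    torch.utils.data.get_worker_info(), which returns None in the main process
    (num_workers=0); in that case a single process should read every file.
    """
    if num_workers < 1 or not 0 <= worker_id < num_workers:
        raise ValueError("need num_workers >= 1 and 0 <= worker_id < num_workers")
    return list(files)[worker_id::num_workers]
```

Because the shares are disjoint and together cover every file, no clip is produced twice per epoch.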

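import torch
from torch.utils.data import DataLoader
import llamass.core

amass = llamass.core.AMASS(
    unpacked_directory,
    overlapping=False,       # clips from the same file don't overlap
    clip_length=1,           # each sample is a clip of 1 frame
    transform=torch.tensor,  # convert every field to a tensor
    shuffle=False,           # shuffling is done by AMASS itself, never the DataLoader
    seed=0,
)
# these are equivalent: IterableLoader supplies worker_init_fn for you
amassloader = DataLoader(amass, batch_size=4, num_workers=2, worker_init_fn=llamass.core.worker_init_fn)
amassloader = llamass.core.IterableLoader(amass, batch_size=4, num_workers=2)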

for data in amassloader:
    for k in data:
        print(k, data[k].size())
    break
poses torch.Size([4, 1, 156])
dmpls torch.Size([4, 1, 8])
trans torch.Size([4, 1, 3])
betas torch.Size([4, 1, 16])
gender torch.Size([4, 1])
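The 156 values in poses are SMPL+H axis-angle parameters for 52 joints (3 each). The sketch below splits them into root orientation, body and hand parts following the slicing used in the AMASS tutorials; split_smplh_pose is a hypothetical helper, and the slices assume the standard SMPL+H layout.

```python
import numpy as np


def split_smplh_pose(poses):
    """Split (..., 156) SMPL+H axis-angle poses into named parts.

    Assumed layout: 52 joints x 3 axis-angle values, root orientation first,
    then 21 body joints, then 30 hand joints (15 per hand).
    """
    poses = np.asarray(poses)
    if poses.shape[-1] != 156:
        raise ValueError(f"expected last dimension 156, got {poses.shape[-1]}")
    return {
        "root_orient": poses[..., :3],   # global orientation of the root joint
        "pose_body": poses[..., 3:66],   # 21 body joints x 3
        "pose_hand": poses[..., 66:],    # 30 hand joints x 3
        "per_joint": poses.reshape(*poses.shape[:-1], 52, 3),
    }
```

Applied to a batch shaped like the output above (4, 1, 156), the parts come out as (4, 1, 3), (4, 1, 63), (4, 1, 90) and (4, 1, 52, 3).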

Visualise Poses

poses = next(iter(llamass.core.IterableLoader(amass, batch_size=200, num_workers=2)))
poses = poses['poses'].squeeze()
# gaitplotlib
import numpy as np
import gaitplotlib.core
import matplotlib.pyplot as plt
plt.style.use('ggplot')

params = gaitplotlib.core.plottable(poses.numpy())

def plot_pose(pose_index, save_to=None):
    fig, axes = plt.subplots(1, 3, figsize=(10,6))

    for d, ax in enumerate(axes):
        dims_to_plot = [i for i in range(3) if i != d]
        joints, skeleton = params[pose_index]["joints"], params.skeleton
        j = joints[:, dims_to_plot]
        ax.scatter(*j.T, color="r", s=0.2)
        for bone in skeleton:
            a = j[bone[0]]
            b = j[bone[1]]
            x, y = list(zip(a, b))
            ax.plot(x, y, color="r", alpha=0.5)
        ax.axes.xaxis.set_ticklabels([])
        ax.axes.yaxis.set_ticklabels([])
        ax.set_aspect('equal', adjustable='box')
    if save_to is not None:
        plt.tight_layout()
        plt.savefig(save_to)
        plt.close()
    else:
        plt.show()
plot_pose(0)


# gaitplotlib
from pathlib import Path
import mediapy as media


animloc = Path(unpacked_directory)/'anim'
animloc.mkdir(exist_ok=True)

def get_frame(i, frameloc=animloc/'frame.jpeg'):
    # render pose i to an image file and read it back as an array
    plot_pose(i, save_to=frameloc)
    return media.read_image(frameloc)

# the first frame fixes the video's (height, width)
img_arr = get_frame(0)

with media.VideoWriter(animloc/'anim.gif', codec='gif', shape=img_arr.shape[:2], fps=10) as w:
    # add every 10th pose as a frame
    for i in range(0, params.vertices.shape[0], 10):
        w.add_image(get_frame(i))

video = media.read_video(animloc/'anim.gif')
media.show_video(video, codec='gif')