PythonOT/POT

Parallelization problem for 3D tensor

Closed this issue · 1 comments

How can I leverage the code to compute the EMD between a pair of 3D tensors and return a 4D tensor as the EMD matrix?

Hello, EMD is already optimized C++ and is not directly parallelizable, but you can use joblib to compute multiple EMDs in parallel. Here is some code that was sent to me by the great @tomMoral:

# %%

import numpy as np
import ot

from joblib import Parallel, delayed

# %%

# Benchmark dimensions: k square matrices of size n x n are generated below.
n = 1000
k = 50

# Stack of k random (n, n) matrices with entries drawn uniformly from [0, 1);
# each slice along the first axis is later handed to `emd` as its matrix.
M = np.random.rand(k, n, n)
# Length-n vector whose entries are all 1/n; `emd` passes it to ot.emd as
# both the first and the second argument.
a = np.ones(n) / n

# %% loop


def emd(M, axes=None):
    """Run ``ot.emd`` on a single matrix using the module-level weights ``a``.

    Parameters
    ----------
    M : array-like
        Matrix forwarded as the third argument of ``ot.emd`` (the cost
        matrix, per the POT documentation).
    axes : optional
        Accepted but never read by this function.

    Returns
    -------
    The value returned by ``ot.emd(a, a, M)``.
    """
    # The same global vector serves as both the source and target weights.
    source_weights = target_weights = a
    transport_plan = ot.emd(source_weights, target_weights, M)
    return transport_plan


# Baseline: process the k slices of M one after another in a plain Python
# loop, storing each result in a preallocated (k, n, n) array.
R_loop = np.zeros((k, n, n))
# The allocation above happens before the timer starts, so only the loop
# itself sits between tic() and toc().
ot.tic()
for i in range(k):
    R_loop[i] = emd(M[i])
ot.toc()


# %% numpy take/stack

def apply_across_axis(func, M, axis=0):
    """Apply ``func`` to every slice of ``M`` along ``axis`` and restack.

    Parameters
    ----------
    func : callable
        Called once per slice with ``M.take(i, axis)`` as its only argument.
    M : array
        Input array; ``M.shape[axis]`` determines the number of calls.
    axis : int, default 0
        Axis that is sliced away before each call and along which the
        results are stacked again.

    Returns
    -------
    numpy.ndarray
        ``np.stack`` of the per-slice results along ``axis``.
    """
    per_slice_results = []
    for idx in range(M.shape[axis]):
        # take() drops `axis`, so func sees an array with one dimension less.
        current_slice = M.take(idx, axis)
        per_slice_results.append(func(current_slice))
    return np.stack(per_slice_results, axis=axis)


# Time the take/stack variant: `emd` is applied to each slice of M along
# axis 0 and the results are stacked back into one array.
ot.tic()
R_numpy = apply_across_axis(emd, M, 0)
ot.toc()


# %% joblib?

def apply_across_axis_joblib(func, M, axis=0, n_jobs=4):
    """Apply ``func`` to every slice of ``M`` along ``axis`` using joblib.

    Parameters
    ----------
    func : callable
        Called once per slice with ``M.take(i, axis)`` as its only argument.
    M : array
        Input array; ``M.shape[axis]`` determines the number of tasks.
    axis : int, default 0
        Axis that is sliced away before each call and along which the
        results are stacked again.
    n_jobs : int, default 4
        Forwarded to ``joblib.Parallel`` as ``n_jobs``.

    Returns
    -------
    numpy.ndarray
        ``np.stack`` of the per-slice results along ``axis``.
    """
    # max_nbytes=None: per the joblib docs this turns off memmapping of
    # large array arguments.
    runner = Parallel(n_jobs=n_jobs, max_nbytes=None)
    deferred_func = delayed(func)
    # Lazy generator of tasks, one per slice along `axis`.
    tasks = (
        deferred_func(M.take(idx, axis))
        for idx in range(M.shape[axis])
    )
    outputs = runner(tasks)
    return np.stack(outputs, axis=axis)


# Untimed call on the first 4 slices before the timer starts -- presumably a
# warm-up so that worker start-up cost is not part of the measurement
# (TODO confirm intent). Its result is overwritten by the timed run below.
R_joblib = apply_across_axis_joblib(emd, M[:4], 0)
# Timed run over all slices of M along axis 0.
ot.tic()
R_joblib = apply_across_axis_joblib(emd, M, 0)
ot.toc()

I think it can be easily adapted to your problem (I don't know what is going to happen with torch tensors, though).