PythonOT/POT

How is the transport plan (coupling) retrieved from `ot.da.LinearGWTransport()`?

ttsesm opened this issue

Hi,

I am trying to use the `ot.da.LinearGWTransport` class, as described in the POT documentation, as follows:

```python
import ot


def match_gaussians(Cs, Ms, ct, mt, verbose=False):
    """Match 3D Gaussians to 2D Gaussians using Linear Gromov-Wasserstein Transport."""
    # Compute cost matrices (compute_cost_matrices is my own helper, not shown)
    C3D, C2D = compute_cost_matrices(Cs, Ms, ct, mt)

    # Initialize and fit the LinearGWTransport model
    gw = ot.da.LinearGWTransport(log=verbose)
    gw.fit(Xs=Ms, Xt=mt, ys=C3D, yt=C2D)

    # Get the transport plan
    transport_plan = gw.coupling  # <-- this attribute doesn't exist

    # Get the linear operator (projection matrix)
    projection_matrix = gw.L  # <-- this attribute doesn't exist

    return projection_matrix, transport_plan
```

where `Cs`, `Ms` and `ct`, `mt` are the covariances and means of my 3D and 2D Gaussians, respectively.
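For reference, as far as I can tell `LinearGWTransport.fit` treats the rows of `Xs` and `Xt` as samples. It estimates one mean and one covariance from each and fits a single linear map between those two Gaussians. When a Gaussian is already given as a mean and a covariance, a map of that kind can be sketched directly. The NumPy code below is a minimal sketch, assuming the construction is as follows. Eigendecompose both covariances and pair the eigenvalues in decreasing order. Rescale the strongest source directions onto the target ones and drop the remaining source directions. Use POT's row-vector convention `y = x @ A + B`. The helper `gaussian_gw_map` and the example data are hypothetical. Because eigenvector signs are arbitrary, the result may differ from a fitted `A_` by sign choices.

```python
import numpy as np


def gaussian_gw_map(m_s, C_s, m_t, C_t):
    """Hypothetical helper: linear map x -> x @ A + B sending N(m_s, C_s) to N(m_t, C_t).

    Sketch only; assumes dim(source) >= dim(target). Row-vector convention:
    A has shape (d_s, d_t) and B has shape (d_t,).
    """
    d_s, d_t = C_s.shape[0], C_t.shape[0]
    assert d_s >= d_t, "this sketch only handles d_s >= d_t"

    # Eigendecompositions with eigenvalues sorted in decreasing order
    w_s, U_s = np.linalg.eigh(C_s)
    w_t, U_t = np.linalg.eigh(C_t)
    idx_s, idx_t = np.argsort(w_s)[::-1], np.argsort(w_t)[::-1]
    w_s, U_s = w_s[idx_s], U_s[:, idx_s]
    w_t, U_t = w_t[idx_t], U_t[:, idx_t]

    # Rescale the d_t strongest source directions onto the target directions;
    # the remaining d_s - d_t source directions get zero rows (they are dropped)
    D = np.zeros((d_s, d_t))
    D[:d_t, :d_t] = np.diag(np.sqrt(w_t / w_s[:d_t]))

    A = U_s @ D @ U_t.T  # (d_s, d_t) "projection" matrix
    B = m_t - m_s @ A    # (d_t,) offset so that the means match
    return A, B


# Hypothetical example: one 3D Gaussian mapped onto one 2D Gaussian
rng = np.random.default_rng(0)
M = rng.normal(size=(3, 3))
C3 = M @ M.T + 3 * np.eye(3)  # symmetric positive definite 3x3 covariance
N = rng.normal(size=(2, 2))
C2 = N @ N.T + np.eye(2)      # symmetric positive definite 2x2 covariance
m3 = np.array([1.0, 2.0, 3.0])
m2 = np.array([0.5, -1.0])

A, B = gaussian_gw_map(m3, C3, m2, C2)
```

By construction the mapped mean is `m3 @ A + B = m2` and the mapped covariance is `A.T @ C3 @ A = C2`, so the linear map pushes the 3D Gaussian exactly onto the 2D one. This is easy to check numerically.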

However, I am not sure how to retrieve the transport plan and the projection matrix from the fitted model. From what I've noticed, I can get the `A_` and `B_` matrices, but I am not sure how they relate to the transport plan and the projection matrix.
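Suppose `A_` and `B_` follow the `ot.da.LinearTransport` convention, where `transform(Xs)` returns `Xs @ A_ + B_`. Then `A_` (shape `(3, 2)` for a 3D source and a 2D target) would act as the projection matrix and `B_` as the offset. A Gaussian-to-Gaussian plan of this kind is deterministic: all mass at `x` goes to `x @ A_ + B_`, so no separate coupling matrix needs to be stored on the model. The minimal NumPy sketch below shows the discrete coupling this induces on a set of samples: a diagonal matrix that pairs each sample with its own image. It uses random stand-ins for `gw.A_` and `gw.B_` and a hypothetical helper, `plan_from_linear_map`.

```python
import numpy as np


def plan_from_linear_map(Xs, A, B):
    """Hypothetical helper: discrete coupling induced by the map x -> x @ A + B.

    Each source sample keeps mass 1/n and is sent entirely to its own image,
    so the coupling between Xs and the mapped points is a diagonal matrix.
    """
    n = Xs.shape[0]
    Xt_mapped = Xs @ A + B  # assumed to match what gw.transform(Xs) returns
    plan = np.eye(n) / n    # plan[i, j] = 1/n if j == i, else 0
    return Xt_mapped, plan


# Stand-ins for a fitted model: A plays the role of gw.A_, B of gw.B_
rng = np.random.default_rng(0)
Xs = rng.normal(size=(4, 3))  # 4 source samples in 3D
A = rng.normal(size=(3, 2))   # (3, 2) "projection" matrix
B = np.array([0.5, -1.0])     # offset in 2D

Xt_mapped, plan = plan_from_linear_map(Xs, A, B)
```

Under this reading, a dense coupling matrix like the one from discrete OT solvers would only be needed when matching two fixed point sets. It is not needed when mapping one Gaussian onto another.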

I would appreciate it if someone has an idea or could provide some feedback.

Thanks.