How is the transport plan (coupling) retrieved from `ot.da.LinearGWTransport()`?
ttsesm opened this issue
Hi,
I am trying to use the `ot.da.LinearGWTransport()` class, as described in the POT documentation, as follows:
```python
import ot


def match_gaussians(Cs, Ms, ct, mt, verbose=False):
    """Match 3D Gaussians to 2D Gaussians using Linear Gromov-Wasserstein Transport."""
    # Compute cost matrices (compute_cost_matrices is my own helper, defined elsewhere)
    C3D, C2D = compute_cost_matrices(Cs, Ms, ct, mt)
    # Initialize and fit the LinearGWTransport model
    gw = ot.da.LinearGWTransport(log=verbose)
    gw.fit(Xs=Ms, Xt=mt, ys=C3D, yt=C2D)
    # Get the transport plan
    transport_plan = gw.coupling  # <-- This doesn't exist
    # Get the linear operator (projection matrix)
    projection_matrix = gw.L  # <-- This doesn't exist
    return projection_matrix, transport_plan
```
where `Cs`, `Ms` and `ct`, `mt` are the covariances and means of my 3D and 2D Gaussians, respectively.
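For reference, a minimal NumPy sketch of the inputs, assuming the estimator follows the usual `ot.da` convention (where `Xs` and `Xt` are sample arrays of shape `(n_samples, dim)` and `ys`/`yt` are optional labels) and fits one Gaussian to each sample cloud. Under that assumption, when only a mean and a covariance are available, samples can be drawn from them first. The helper `sample_gaussian` and all numeric values are hypothetical stand-ins.

```python
import numpy as np


def sample_gaussian(mean, cov, n, seed=0):
    """Draw n samples from N(mean, cov); hypothetical helper for building Xs/Xt."""
    rng = np.random.default_rng(seed)
    return rng.multivariate_normal(mean, cov, size=n)


# Hypothetical 3D source and 2D target Gaussians (stand-ins for Ms/Cs and mt/ct).
Ms = np.array([0.0, 1.0, 2.0])
Cs = np.diag([3.0, 2.0, 1.0])
mt = np.array([5.0, -1.0])
ct = np.diag([2.0, 0.5])

Xs = sample_gaussian(Ms, Cs, n=500, seed=0)  # shape (500, 3)
Xt = sample_gaussian(mt, ct, n=500, seed=1)  # shape (500, 2)

# Empirical statistics that a Gaussian fit would estimate from the source samples.
mu_s_hat = Xs.mean(axis=0)
cov_s_hat = np.cov(Xs, rowvar=False)
```

Under this assumption, the sample arrays themselves (rather than the means plus cost matrices) would be what goes into `gw.fit(Xs=Xs, Xt=Xt)`.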
However, I am not sure how to retrieve the transport plan and the projection matrix from the fitted model. As far as I can tell, I can get back the `A` and `B` matrices, but I am not sure how these relate to the transport plan and the projection matrix.
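One way to read `A` and `B` is as an affine map rather than a matrix of weights. A minimal NumPy sketch, assuming the fitted map acts on row vectors as `x -> x @ A + B` (which appears to be the convention of `ot.da.LinearTransport.transform`) and pushes the source Gaussian onto the target one. If that holds, the transport plan is the joint Gaussian law of `(X, X @ A + B)`. That law is degenerate (its rank is the source dimension), and `A` plays the role of the 3x2 projection matrix. The function `gaussian_coupling_from_map` and all numeric values are hypothetical.

```python
import numpy as np


def gaussian_coupling_from_map(mu_s, cov_s, A, B):
    """Mean and covariance of the joint law of (X, X @ A + B) for X ~ N(mu_s, cov_s).

    Row-vector convention: A has shape (d_s, d_t), B has shape (d_t,).
    """
    mu_t = mu_s @ A + B        # mean of the pushed-forward Gaussian
    cov_st = cov_s @ A         # cross-covariance Cov(X, X A)
    cov_t = A.T @ cov_s @ A    # covariance of X A
    joint_mean = np.concatenate([mu_s, mu_t])
    joint_cov = np.block([[cov_s, cov_st], [cov_st.T, cov_t]])
    return joint_mean, joint_cov


# Hypothetical 3D -> 2D map (in practice, the A and B of the fitted model).
mu_s = np.array([0.0, 1.0, 2.0])
cov_s = np.diag([3.0, 2.0, 1.0])
A = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])
B = np.array([5.0, -1.0])

joint_mean, joint_cov = gaussian_coupling_from_map(mu_s, cov_s, A, B)

# Samples from this coupling are simply pairs (x, x @ A + B).
rng = np.random.default_rng(0)
X = rng.multivariate_normal(mu_s, cov_s, size=1000)
pairs = np.hstack([X, X @ A + B])  # each row: a 3D point followed by its 2D image
```

Under this reading there is no n-by-m coupling matrix to retrieve, because the plan is concentrated on the graph of the map. If a discrete matching between individual Gaussian components is what is needed, the cost matrices `C3D` and `C2D` look more like inputs for a discrete Gromov-Wasserstein solver such as `ot.gromov.gromov_wasserstein`, which does return an explicit coupling matrix; that is a different method from `LinearGWTransport`.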
I would appreciate it if someone has an idea or can provide some feedback.
Thanks.