PythonOT/POT

question regarding applying transport

Closed this issue · 0 comments

hi,
I would like to identify a transport plan given two known input grids , and output grids and apply it to a new grid however it does not work:

import ot

def compute_transport_plan(tensor1, tensor2):
    """
    Compute the optimal transport plan from tensor1 to tensor2.
    """
    tensor1 = tensor1.flatten()
    tensor2 = tensor2.flatten()
    
    # Assuming uniform weights for the discrete distributions
    a = np.ones(len(tensor1)) / len(tensor1)
    b = np.ones(len(tensor2)) / len(tensor2)
    
    M = ot.dist(tensor1[:, None], tensor2[:, None], metric='euclidean')
    
    # Compute the optimal transport plan
    transport_plan = ot.emd(a, b, M)
    return transport_plan

def apply_transport_plan(tensor, transport_plan, target_shape):
  """
  Apply the transport plan to transform the tensor.
  """
  tensor = tensor.flatten()
  
  # Multiplication might work without reshape depending on tensor shapes
  transformed_tensor = np.dot(transport_plan, tensor)
  # Reshape to target shape only if necessary
  if len(transformed_tensor.shape) != len(target_shape):
      transformed_tensor = transformed_tensor.reshape(target_shape)
  
  return transformed_tensor

# Example usage
tensor1 = np.array([[1, 2, 3], [4, 5, 6]])
tensor2 = np.array([[7, 8, 9], [10, 11, 12]])
new_tensor = np.array([[13, 14, 15], [16, 17, 18]])


#  Compute Optimal Transport Plan
transport_plan = compute_transport_plan(tensor1, tensor2)
print(f"Optimal Transport Plan: \n{transport_plan}")

# Apply the Transport Plan to a New Tensor
transformed_tensor = apply_transport_plan(new_tensor, transport_plan, tensor2.shape)
print(f"Transformed Tensor: \n{transformed_tensor}")

I have this result :
Optimal Transport Plan:
[[0. 0. 0. 0. 0.16666667 0. ]
[0. 0.16666667 0. 0. 0. 0. ]
[0. 0. 0.16666667 0. 0. 0. ]
[0. 0. 0. 0. 0. 0.16666667]
[0.16666667 0. 0. 0. 0. 0. ]
[0. 0. 0. 0.16666667 0. 0. ]]

Transformed Tensor:
[[2.83333333 2.33333333 2.5 ]
[3. 2.16666667 2.66666667]]

I expected [19,20,21],[22,23,24]