mathLab/PINA

SimplexDomain error when n_vertices > 3

ndem0 opened this issue · 3 comments

ndem0 commented

Describe the bug
`SimplexDomain` raises an error when more than three 2-D vertices are passed.

To Reproduce

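import torch
from pina import LabelTensor
from pina.geometry.simplex import SimplexDomain

# Four 2-D vertices: one more than a 2-simplex (triangle) admits
spatial_domain2 = SimplexDomain(
    [
        LabelTensor(torch.tensor([[ 0., -2.]]), labels=["x", "y"]),
        LabelTensor(torch.tensor([[-.5, -.5]]), labels=["x", "y"]),
        LabelTensor(torch.tensor([[-2.,  0.]]), labels=["x", "y"]),
        LabelTensor(torch.tensor([[-.5,  .5]]), labels=["x", "y"]),
    ]
)
pts = spatial_domain2.sample(100)  # raises the RuntimeError below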

Output

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[47], line 23
      2 spatial_domain = SimplexDomain(
      3                     [
      4                         LabelTensor(torch.tensor([[0, 0]]), labels=["x", "y"]),
   (...)
      7                     ]
      8                 )
     10 spatial_domain2 = SimplexDomain(
     11                     [
     12                         LabelTensor(torch.tensor([[ 0., -2.]]), labels=["x", "y"]),
   (...)
     21                     ]
     22                 )
---> 23 pts = spatial_domain2.sample(100)
     24 fig, ax = plt.subplots()
     25 plot_scatter(ax, pts, 'Simplex Domain')

File ~/.local/lib/python3.9/site-packages/pina/geometry/simplex.py:216, in SimplexDomain.sample(self, n, mode, variables)
    214         sample_pts = self._sample_boundary_randomly(n)
    215     else:
--> 216         sample_pts = self._sample_interior_randomly(n, variables)
    218 else:
    219     raise NotImplementedError(f"mode={mode} is not implemented.")

File ~/.local/lib/python3.9/site-packages/pina/geometry/simplex.py:155, in SimplexDomain._sample_interior_randomly(self, n, variables)
    150 while len(sampled_points) < n:
    151     sampled_point = self._cartesian_bound.sample(
    152         n=1, mode="random", variables=variables
    153     )
--> 155     if self.is_inside(sampled_point, self._sample_surface):
    156         sampled_points.append(sampled_point)
    157 return torch.cat(sampled_points, dim=0)

File ~/.local/lib/python3.9/site-packages/pina/geometry/simplex.py:116, in SimplexDomain.is_inside(self, point, check_border)
    113 point_shift = point.T - (self._vertices_matrix.T)[:, None, -1]
    115 # compute barycentric coordinates
--> 116 lambda_ = torch.linalg.solve(self._vectors_shifted * 1.0, point_shift * 1.0)
    117 lambda_1 = 1.0 - torch.sum(lambda_)
    118 lambdas = torch.vstack([lambda_, lambda_1])

File ~/.local/lib/python3.9/site-packages/torch/_tensor.py:1386, in Tensor.__torch_function__(cls, func, types, args, kwargs)
   1383     return NotImplemented
   1385 with _C.DisableTorchFunctionSubclass():
-> 1386     ret = func(*args, **kwargs)
   1387     if func in get_default_nowrap_functions():
   1388         return ret

RuntimeError: linalg.solve: A must be batches of square matrices, but they are 2 by 3 matrices
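To see why the vertex count matters, here is a minimal standalone sketch of the barycentric test that `is_inside` performs in the traceback, using plain `torch` tensors. The function name `is_inside_simplex` is hypothetical and this is not PINA's implementation; it only mirrors the steps visible above: shift by the last vertex, solve for the barycentric coordinates, and derive the last coordinate so they all sum to 1.

```python
import torch


def is_inside_simplex(vertices: torch.Tensor, point: torch.Tensor) -> bool:
    """Hypothetical helper: True if `point` lies in the simplex spanned by `vertices`.

    vertices: (n_vertices, dim) tensor; a valid simplex needs n_vertices == dim + 1.
    point:    (dim,) tensor.
    """
    vertices = vertices.double()
    point = point.double()
    last = vertices[-1]
    # Columns are the edge vectors v_i - v_last for i = 0..n_vertices-2.
    # Shape is (dim, n_vertices - 1): square only when n_vertices == dim + 1.
    edges = (vertices[:-1] - last).T
    # Solve edges @ lambda_ = point - v_last for the first n_vertices - 1
    # barycentric coordinates; a non-square `edges` makes this call raise.
    lambda_ = torch.linalg.solve(edges, (point - last).unsqueeze(-1)).squeeze(-1)
    # The last coordinate is fixed by requiring all coordinates to sum to 1.
    lambdas = torch.cat([lambda_, (1.0 - lambda_.sum()).unsqueeze(0)])
    # Inside (or on the border) iff every barycentric coordinate is >= 0.
    return bool(torch.all(lambdas >= 0))


triangle = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
print(is_inside_simplex(triangle, torch.tensor([0.2, 0.2])))  # True
print(is_inside_simplex(triangle, torch.tensor([1.0, 1.0])))  # False

# The four 2-D vertices from the reproduction: `edges` is 2 x 3, so the solve fails.
quad = torch.tensor([[0.0, -2.0], [-0.5, -0.5], [-2.0, 0.0], [-0.5, 0.5]])
try:
    is_inside_simplex(quad, torch.tensor([-0.5, 0.0]))
except RuntimeError as err:
    print(err)
```

Because the coordinates come from a square linear solve, the number of vertices is tied to the dimension: with `dim + 1` vertices `edges` is square, and any other count makes `torch.linalg.solve` reject it.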

Additional context
Version 0.1

I can work on it; I am already aware of the bug. It should be a short fix.

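Hi @ndem0, this is a bug in the constructor (`__init__`). An N-dimensional simplex has exactly N+1 vertices, each of dimension N. In your case you are building an N=2 dimensional simplex (a triangle): each vertex is 2-dimensional, so there must be exactly N+1 = 3 vertices, not four. With four vertices the matrix of shifted vertex vectors is 2 by 3 instead of square, which is why `torch.linalg.solve` fails in `is_inside`. Let's fix it in `__init__` so that an invalid number of vertices is caught when the domain is built.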
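A minimal sketch of the kind of constructor-time check described above, assuming vertices are passed as tensors shaped like those in the reproduction. It uses plain `torch` tensors of shape (1, 2) rather than `LabelTensor`. The function `validate_simplex_vertices` is hypothetical; the actual fix lives in #196 and may differ.

```python
import torch


def validate_simplex_vertices(vertices):
    """Hypothetical check: an N-dimensional simplex needs exactly N + 1 vertices of dimension N.

    Returns the dimension N if the vertex list is valid, otherwise raises ValueError.
    """
    # Flatten each vertex, e.g. a (1, 2) tensor becomes a length-2 vector.
    points = [torch.as_tensor(v).reshape(-1) for v in vertices]
    dims = {p.numel() for p in points}
    if len(dims) != 1:
        raise ValueError(f"All vertices must have the same dimension, got {sorted(dims)}.")
    dim = dims.pop()
    if len(points) != dim + 1:
        raise ValueError(
            f"A {dim}-dimensional simplex needs exactly {dim + 1} vertices, got {len(points)}."
        )
    return dim


triangle = [torch.tensor([[0.0, 0.0]]), torch.tensor([[1.0, 0.0]]), torch.tensor([[0.0, 1.0]])]
print(validate_simplex_vertices(triangle))  # 2

quad = [
    torch.tensor([[0.0, -2.0]]),
    torch.tensor([[-0.5, -0.5]]),
    torch.tensor([[-2.0, 0.0]]),
    torch.tensor([[-0.5, 0.5]]),
]
try:
    validate_simplex_vertices(quad)
except ValueError as err:
    print(err)  # A 2-dimensional simplex needs exactly 3 vertices, got 4.
```

Raising in the constructor reports the problem when the domain is created, instead of as a shape error deep inside `sample`.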

Solved in #196; closing the issue.