SimplexDomain error when n_vertices > 3
ndem0 opened this issue · 3 comments
ndem0 commented
Describe the bug
`SimplexDomain` raises an error when more than 3 vertices are passed.
To Reproduce
```python
import torch
from pina import LabelTensor
from pina.geometry.simplex import SimplexDomain

# Four vertices, each 2D
spatial_domain2 = SimplexDomain(
    [
        LabelTensor(torch.tensor([[ 0., -2.]]), labels=["x", "y"]),
        LabelTensor(torch.tensor([[-.5, -.5]]), labels=["x", "y"]),
        LabelTensor(torch.tensor([[-2., 0.]]), labels=["x", "y"]),
        LabelTensor(torch.tensor([[-.5, .5]]), labels=["x", "y"]),
    ]
)
pts = spatial_domain2.sample(100)  # raises RuntimeError
```
Output
```
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[47], line 23
2 spatial_domain = SimplexDomain(
3 [
4 LabelTensor(torch.tensor([[0, 0]]), labels=["x", "y"]),
(...)
7 ]
8 )
10 spatial_domain2 = SimplexDomain(
11 [
12 LabelTensor(torch.tensor([[ 0., -2.]]), labels=["x", "y"]),
(...)
21 ]
22 )
---> 23 pts = spatial_domain2.sample(100)
24 fig, ax = plt.subplots()
25 plot_scatter(ax, pts, 'Simplex Domain')

File ~/.local/lib/python3.9/site-packages/pina/geometry/simplex.py:216, in SimplexDomain.sample(self, n, mode, variables)
214 sample_pts = self._sample_boundary_randomly(n)
215 else:
--> 216 sample_pts = self._sample_interior_randomly(n, variables)
218 else:
219 raise NotImplementedError(f"mode={mode} is not implemented.")

File ~/.local/lib/python3.9/site-packages/pina/geometry/simplex.py:155, in SimplexDomain._sample_interior_randomly(self, n, variables)
150 while len(sampled_points) < n:
151 sampled_point = self._cartesian_bound.sample(
152 n=1, mode="random", variables=variables
153 )
--> 155 if self.is_inside(sampled_point, self._sample_surface):
156 sampled_points.append(sampled_point)
157 return torch.cat(sampled_points, dim=0)

File ~/.local/lib/python3.9/site-packages/pina/geometry/simplex.py:116, in SimplexDomain.is_inside(self, point, check_border)
113 point_shift = point.T - (self._vertices_matrix.T)[:, None, -1]
115 # compute barycentric coordinates
--> 116 lambda_ = torch.linalg.solve(self._vectors_shifted * 1.0, point_shift * 1.0)
117 lambda_1 = 1.0 - torch.sum(lambda_)
118 lambdas = torch.vstack([lambda_, lambda_1])

File ~/.local/lib/python3.9/site-packages/torch/_tensor.py:1386, in Tensor.__torch_function__(cls, func, types, args, kwargs)
1383 return NotImplemented
1385 with _C.DisableTorchFunctionSubclass():
-> 1386 ret = func(*args, **kwargs)
1387 if func in get_default_nowrap_functions():
1388 return ret

RuntimeError: linalg.solve: A must be batches of square matrices, but they are 2 by 3 matrices
```
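The failing call is the barycentric-coordinate solve in `is_inside`. The sketch below mimics that computation in plain `torch` to show why the vertex count matters. It assumes, based on the `point_shift` line in the traceback, that the system matrix has one column per vertex except the last, each column being that vertex minus the last vertex. `barycentric` is a hypothetical helper, not PINA's code. With three 2D vertices the matrix is 2×2 and `torch.linalg.solve` succeeds; with the four vertices from the report it is 2×3, which reproduces the error above.

```python
import torch


def barycentric(vertices: torch.Tensor, point: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: barycentric coordinates of `point` w.r.t. a simplex.

    vertices: (n_vertices, dim) tensor; point: (dim,) tensor.
    Assumes the shift-by-last-vertex layout suggested by the traceback.
    """
    last = vertices[-1]
    # Columns are v_i - v_last for every vertex but the last: shape (dim, n_vertices - 1)
    vectors_shifted = (vertices[:-1] - last).T
    # Point expressed relative to the last vertex: shape (dim, 1)
    point_shift = (point - last)[:, None]
    # Square only when n_vertices == dim + 1; otherwise linalg.solve raises
    lambda_ = torch.linalg.solve(vectors_shifted, point_shift).squeeze(1)
    # The last coordinate is chosen so that all coordinates sum to 1
    return torch.cat([lambda_, (1.0 - lambda_.sum()).unsqueeze(0)])


triangle = torch.tensor([[0.0, -2.0], [-0.5, -0.5], [-2.0, 0.0]])
# All coordinates non-negative means the point lies inside the triangle
print(barycentric(triangle, torch.tensor([-0.8, -0.8])))

four_vertices = torch.tensor([[0.0, -2.0], [-0.5, -0.5], [-2.0, 0.0], [-0.5, 0.5]])
try:
    barycentric(four_vertices, torch.tensor([-0.8, -0.8]))
except RuntimeError as err:
    # The 2x3 matrix triggers the non-square-matrix error seen in the traceback
    print(err)
```

A point is inside the simplex exactly when all of its barycentric coordinates are non-negative, which is why the solve only makes sense for N+1 vertices in N dimensions.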
Additional context
Version 0.1
dario-coscia commented
I can work on it; I am already aware of the bug. It should be a short fix.
dario-coscia commented
Hi @ndem0, this is a bug in the `__init__` constructor. An N-dimensional simplex is composed of N+1 vertices, each of dimension N. In your case you are building an N=2 dimensional simplex (a triangle): each vertex is 2-dimensional, so there must be exactly N+1 = 3 vertices, but four were passed. The constructor should reject this input, so let's fix it in `__init__`.
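The check described above could look roughly like the following, run at the start of the constructor. This is a minimal sketch: `check_simplex_vertices` is a hypothetical name, each vertex is assumed to be a `(1, N)` tensor as in the reproducer, and it is not necessarily how the fix was actually implemented.

```python
import torch


def check_simplex_vertices(vertices):
    """Hypothetical validation: an N-dimensional simplex needs exactly N + 1
    vertices, each of dimension N. Each vertex is assumed to be a (1, N) tensor."""
    # Collect the dimension (last axis size) of every vertex
    dims = {vertex.shape[-1] for vertex in vertices}
    if len(dims) != 1:
        raise ValueError(f"All vertices must have the same dimension, got {sorted(dims)}.")
    (dim,) = dims
    # An N-dimensional simplex has exactly N + 1 vertices
    if len(vertices) != dim + 1:
        raise ValueError(
            f"A {dim}-dimensional simplex needs {dim + 1} vertices, got {len(vertices)}."
        )


triangle = [
    torch.tensor([[0.0, -2.0]]),
    torch.tensor([[-0.5, -0.5]]),
    torch.tensor([[-2.0, 0.0]]),
]
check_simplex_vertices(triangle)  # passes: 3 vertices in 2D

try:
    check_simplex_vertices(triangle + [torch.tensor([[-0.5, 0.5]])])
except ValueError as err:
    print(err)  # A 2-dimensional simplex needs 3 vertices, got 4.
```

Raising in the constructor surfaces the mistake when the domain is built, rather than later inside `sample`, where it only shows up as a non-square `linalg.solve` call.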
dario-coscia commented
Solved in #196; closing the issue.