mnist: RuntimeError: linalg.solve: A must be batches of square matrices, but they are 501 by 10 matrices
jabowery opened this issue · 1 comment
jabowery commented
Running example/mnist.py, I get:
/torchesn/nn/echo_state_network.py", line 237, in fit
W = torch.linalg.solve(self.XTy,
RuntimeError: linalg.solve: A must be batches of square matrices, but they are 501 by 10 matrices
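For reference, the shape complaint can be reproduced outside the library. This is a minimal sketch, assuming a 501 by 10 right-hand side like the one named in the message. The tensors are hypothetical stand-ins, not the values torchesn computes. torch.linalg.solve(A, B) takes the square coefficient matrix as its first argument, so passing the right-hand side first trips the squareness check:

```python
import torch

# Hypothetical stand-ins: a 501 x 10 right-hand side and a 501 x 501 system matrix.
XTy = torch.randn(501, 10)
A = torch.eye(501)

try:
    # Arguments in the order used in the traceback: right-hand side first.
    torch.linalg.solve(XTy, A)
    raised = False
except RuntimeError as err:
    raised = True
    print(err)  # the exact wording may vary between PyTorch versions
```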
stefanonardo commented
Hi @jabowery, can you try using torch.cholesky_solve() instead of torch.linalg.solve()?
So, change the code from:
W = torch.linalg.solve(self.XTy,
                       self.XTX + self.lambda_reg * torch.eye(
                           self.XTX.size(0), device=self.XTX.device))[0].t()
to:
W = torch.cholesky_solve(self.XTy,
                         self.XTX + self.lambda_reg * torch.eye(
                             self.XTX.size(0), device=self.XTX.device)).t()
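A note on the two calls, offered as a sketch rather than the library's confirmed fix. torch.cholesky_solve(B, L) expects the Cholesky factor L of the system matrix as its second argument, not the matrix itself, so the regularized matrix would need to go through torch.linalg.cholesky first. Alternatively, torch.linalg.solve can be kept if the arguments are given as (A, B). The trailing [0] looks like a leftover from the deprecated torch.solve(B, A), which returned a (solution, LU) tuple, and is dropped here. The names X, y, XTX, XTy and lambda_reg below are hypothetical stand-ins for the attributes used in fit, filled with random data:

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins for self.XTX, self.XTy and self.lambda_reg.
X = torch.randn(1000, 501, dtype=torch.float64)
y = torch.randn(1000, 10, dtype=torch.float64)
XTX = X.t() @ X          # 501 x 501
XTy = X.t() @ y          # 501 x 10
lambda_reg = 1.0

# Regularized system matrix; symmetric positive definite for lambda_reg > 0.
A = XTX + lambda_reg * torch.eye(
    XTX.size(0), device=XTX.device, dtype=XTX.dtype)

# Option 1: general solver, coefficient matrix first, right-hand side second.
W_solve = torch.linalg.solve(A, XTy).t()

# Option 2: factor A once, then pass the factor (not A) to cholesky_solve.
L = torch.linalg.cholesky(A)
W_chol = torch.cholesky_solve(XTy, L).t()

print(W_solve.shape, torch.allclose(W_solve, W_chol))
```

Both options solve the same ridge-regression system, so they should agree up to floating-point error. The Cholesky route relies on A being symmetric positive definite, which holds here because lambda_reg is positive.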