stefanonardo/pytorch-esn

mnist: RuntimeError: linalg.solve: A must be batches of square matrices, but they are 501 by 10 matrices

jabowery opened this issue · 1 comment

Running example/mnist.py, I get:

/torchesn/nn/echo_state_network.py", line 237, in fit
    W = torch.linalg.solve(self.XTy,
RuntimeError: linalg.solve: A must be batches of square matrices, but they are 501 by 10 matrices
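The error comes from the argument order. The call in fit() appears to have been written for the deprecated torch.solve(B, A), which took the right-hand side first and returned a (solution, LU) tuple, which would explain the trailing [0]. torch.linalg.solve(A, B) instead takes the square coefficient matrix first and returns the solution tensor directly, so self.XTy (501 x 10) ends up in the A slot and fails the squareness check. The sketch below reproduces this on random stand-in data. The sample count, lambda_reg value and data are hypothetical; only the 501 x 10 shape is taken from the traceback.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in data: 2000 samples, 501 features (the size in the
# error message) and 10 one-hot targets. Not the repository's real states.
n_samples, n_features, n_classes = 2000, 501, 10
X = torch.randn(n_samples, n_features, dtype=torch.float64)
labels = torch.randint(0, n_classes, (n_samples,))
y = torch.nn.functional.one_hot(labels, n_classes).to(torch.float64)

XTX = X.t() @ X   # (501, 501)
XTy = X.t() @ y   # (501, 10)
lambda_reg = 1e-2  # hypothetical regularization strength
A = XTX + lambda_reg * torch.eye(n_features, dtype=XTX.dtype)

# Old torch.solve(B, A) argument order: linalg.solve treats its first
# argument as A, so the 501 x 10 matrix fails the squareness check.
try:
    torch.linalg.solve(XTy, A)
    failed = False
except RuntimeError as err:
    failed = True
    print("RuntimeError:", err)

# torch.linalg.solve(A, B) solves A @ W = B and returns a tensor, not a
# tuple, so no [0] is needed. Transposing gives (n_classes, n_features).
W = torch.linalg.solve(A, XTy).t()
print(W.shape)  # torch.Size([10, 501])
```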

Hi @jabowery,
can you try using torch.cholesky_solve() instead of torch.linalg.solve()?

So, change the code from:

W = torch.linalg.solve(self.XTy,
                           self.XTX + self.lambda_reg * torch.eye(
                               self.XTX.size(0), device=self.XTX.device))[0].t()

to

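# cholesky_solve(B, L) expects the lower-triangular Cholesky factor L, not the matrix itself
W = torch.cholesky_solve(self.XTy,
                           torch.linalg.cholesky(self.XTX + self.lambda_reg * torch.eye(
                               self.XTX.size(0), device=self.XTX.device))).t()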
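torch.cholesky_solve(B, L) solves A X = B given the lower-triangular Cholesky factor L of A (A = L Lᵀ), so the regularized matrix has to be factored with torch.linalg.cholesky() before the call. For lambda_reg > 0, XTX + lambda_reg * I is symmetric positive definite, so the factorization exists. A minimal sketch on hypothetical random data (same 501-feature, 10-class shapes as the traceback), checking that this route matches torch.linalg.solve and that passing the unfactored matrix gives a different answer:

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in data with the shapes from the traceback:
# 501 features and 10 one-hot MNIST classes.
n_samples, n_features, n_classes = 2000, 501, 10
X = torch.randn(n_samples, n_features, dtype=torch.float64)
y = torch.nn.functional.one_hot(
    torch.randint(0, n_classes, (n_samples,)), n_classes).to(torch.float64)

XTX = X.t() @ X
XTy = X.t() @ y
lambda_reg = 1e-2  # hypothetical; any value > 0 keeps A positive definite
A = XTX + lambda_reg * torch.eye(n_features, dtype=XTX.dtype, device=XTX.device)

# Factor A = L @ L.T first; cholesky_solve takes L, not A.
L = torch.linalg.cholesky(A)
W_chol = torch.cholesky_solve(XTy, L).t()  # (10, 501)

# Reference solution from the general solver (A first, then B).
W_lu = torch.linalg.solve(A, XTy).t()

# Passing A itself treats it as if it were the factor L, which
# silently solves a different system.
W_wrong = torch.cholesky_solve(XTy, A).t()

print(W_chol.shape)  # torch.Size([10, 501])
print(torch.allclose(W_chol, W_lu, atol=1e-8))
print(torch.allclose(W_wrong, W_lu))
```

Cholesky is a reasonable fit here because the regularized normal-equations matrix is symmetric positive definite, which lets it skip the pivoting of the general LU-based solver; both routes give the same readout up to rounding.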