Error with VMC_SRt
PhilipVinc opened this issue
Discussed in https://github.com/orgs/netket/discussions/1642
Originally posted by aashmore November 8, 2023
Hello!
First of all, let me say thank you to all of the NetKet developers. What an amazing tool!
I'm trying to get VMC_SRt working (amazing that it has been added!). Unfortunately, it is currently giving me an error when I run the driver, while using standard SR works just fine.
The model is built on a discrete spin Hilbert space with complex parameters and complex output.
Using
gs = nk.driver.VMC(H, optimizer, variational_state=vstate, preconditioner=nk.optimizer.SR(diag_shift=0.1, holomorphic=False), holomorphic=False)
the model trains and converges with no errors (and reproduces known results from the literature).
Changing this to
import netket.experimental as nkx
gs = nkx.driver.VMC_SRt(H, optimizer, diag_shift=0.01, variational_state=vstate, jacobian_mode="complex")
leads to the error
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
[<ipython-input-33-d2c8dffbbffd>](https://localhost:8080/#) in <cell line: 3>()
1 log = nk.logging.RuntimeLog()
2
----> 3 gs.run(n_iter=100, out=log)
4
5 ffn_energy = vstate.expect(H)
8 frames
[/usr/local/lib/python3.10/dist-packages/netket/driver/abstract_variational_driver.py](https://localhost:8080/#) in run(self, n_iter, out, obs, show_progress, save_params_every, write_every, step_size, callback)
257 first_step = True
258
--> 259 for step in self.iter(n_iter, step_size):
260 log_data = self.estimate(obs)
261 self._log_additional_data(log_data, step)
[/usr/local/lib/python3.10/dist-packages/netket/driver/abstract_variational_driver.py](https://localhost:8080/#) in iter(self, n_steps, step)
167 for _ in range(0, n_steps, step):
168 for i in range(0, step):
--> 169 dp = self._forward_and_backward()
170 if i == 0:
171 yield self.step_count
[/usr/local/lib/python3.10/dist-packages/netket/experimental/driver/vmc_srt.py](https://localhost:8080/#) in _forward_and_backward(self)
244 )
245
--> 246 self._dp = self._unravel_params_fn(updates)
247
248 return self._dp
[... skipping hidden 13 frame]
[/usr/local/lib/python3.10/dist-packages/jax/_src/flatten_util.py](https://localhost:8080/#) in unravel_pytree(treedef, unravel_list, flat)
51
52 def unravel_pytree(treedef, unravel_list, flat):
---> 53 return tree_unflatten(treedef, unravel_list(flat))
54
55 def _ravel_list(lst):
[... skipping hidden 1 frame]
[/usr/local/lib/python3.10/dist-packages/jax/_src/flatten_util.py](https://localhost:8080/#) in _unravel_list_single_dtype(indices, shapes, arr)
76 def _unravel_list_single_dtype(indices, shapes, arr):
77 chunks = jnp.split(arr, indices[:-1])
---> 78 return [chunk.reshape(shape) for chunk, shape in zip(chunks, shapes)]
79
80 def _unravel_list(indices, shapes, from_dtypes, to_dtype, arr):
[/usr/local/lib/python3.10/dist-packages/jax/_src/flatten_util.py](https://localhost:8080/#) in <listcomp>(.0)
76 def _unravel_list_single_dtype(indices, shapes, arr):
77 chunks = jnp.split(arr, indices[:-1])
---> 78 return [chunk.reshape(shape) for chunk, shape in zip(chunks, shapes)]
79
80 def _unravel_list(indices, shapes, from_dtypes, to_dtype, arr):
[/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/array_methods.py](https://localhost:8080/#) in meth(self, *args, **kwargs)
731 def _forward_method_to_aval(name):
732 def meth(self, *args, **kwargs):
--> 733 return getattr(self.aval, name).fun(self, *args, **kwargs)
734 return meth
735
[/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/array_methods.py](https://localhost:8080/#) in _reshape(a, order, *args)
141 newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
142 if order == "C":
--> 143 return lax.reshape(a, newshape, None)
144 elif order == "F":
145 dims = list(range(a.ndim)[::-1])
[... skipping hidden 8 frame]
[/usr/local/lib/python3.10/dist-packages/jax/_src/lax/lax.py](https://localhost:8080/#) in _reshape_shape_rule(operand, new_sizes, dimensions)
3376 not math.prod(np.shape(operand)) == math.prod(new_sizes)):
3377 msg = 'reshape total size must be unchanged, got new_sizes {} for shape {}.'
-> 3378 raise TypeError(msg.format(new_sizes, np.shape(operand)))
3379 if dimensions is not None:
3380 if set(dimensions) != set(range(np.ndim(operand))):
TypeError: reshape total size must be unchanged, got new_sizes (2, 3, 7) for shape (84,).
From my (amateur) reading of the trace, it looks like something is going wrong with chunking? The model itself is rather complicated, so if it isn't a problem with VMC_SRt itself, I'll put together a minimal working example and edit my question. I'm just a little confused as to why it would work with standard VMC but not VMC_SRt, since my understanding is that the difference is just in how one computes a certain matrix.
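For what it's worth, the sizes in the final error message can be checked by hand. The target shape (2, 3, 7) holds 42 elements, while the flat chunk being reshaped holds 84, exactly twice as many. Here is a minimal sketch of that arithmetic in plain Python. The helper name `diagnose_unravel` is hypothetical and is not part of NetKet or JAX. The factor of two would be consistent with a complex parameter being handled as separate real and imaginary parts somewhere in the update, but that is only a guess that the trace alone does not confirm.

```python
import math


def diagnose_unravel(flat_size, leaf_shapes):
    """Compare a flat vector's length with the total size of the target shapes.

    Hypothetical helper for illustration only (not a NetKet or JAX function).
    Returns the expected flat length and the ratio flat_size / expected.
    """
    expected = sum(math.prod(shape) for shape in leaf_shapes)
    return expected, flat_size / expected


# Numbers taken from the error message:
# "got new_sizes (2, 3, 7) for shape (84,)"
expected, ratio = diagnose_unravel(84, [(2, 3, 7)])
print(expected, ratio)  # 42 2.0
```

A ratio of exactly 2 suggests a mismatch in how the parameters are counted rather than an arbitrary shape bug.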
Thank you in advance!