netket/netket

Error with VMC_SRt

PhilipVinc opened this issue

Discussed in https://github.com/orgs/netket/discussions/1642

Originally posted by aashmore November 8, 2023
Hello!

First of all, let me say thank you to all of the NetKet developers. What an amazing tool!

I'm trying to get VMC_SRt working (amazing that it has been added!). Unfortunately, it currently raises an error when I run the driver, whereas standard SR works just fine.

The model is built on a discrete spin Hilbert space with complex parameters and complex output.

Using

gs = nk.driver.VMC(H, optimizer, variational_state=vstate, preconditioner=nk.optimizer.SR(diag_shift=0.1, holomorphic=False), holomorphic=False)

the model trains and converges with no errors (and reproduces known results from the literature).

Changing this to

import netket.experimental as nkx
gs = nkx.driver.VMC_SRt(H, optimizer, diag_shift=0.01, variational_state=vstate, jacobian_mode="complex")

leads to the error

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-33-d2c8dffbbffd> in <cell line: 3>()
      1 log = nk.logging.RuntimeLog()
      2 
----> 3 gs.run(n_iter=100, out=log)
      4 
      5 ffn_energy = vstate.expect(H)

/usr/local/lib/python3.10/dist-packages/netket/driver/abstract_variational_driver.py in run(self, n_iter, out, obs, show_progress, save_params_every, write_every, step_size, callback)
    257             first_step = True
    258 
--> 259             for step in self.iter(n_iter, step_size):
    260                 log_data = self.estimate(obs)
    261                 self._log_additional_data(log_data, step)

/usr/local/lib/python3.10/dist-packages/netket/driver/abstract_variational_driver.py in iter(self, n_steps, step)
    167         for _ in range(0, n_steps, step):
    168             for i in range(0, step):
--> 169                 dp = self._forward_and_backward()
    170                 if i == 0:
    171                     yield self.step_count

/usr/local/lib/python3.10/dist-packages/netket/experimental/driver/vmc_srt.py in _forward_and_backward(self)
    244         )
    245 
--> 246         self._dp = self._unravel_params_fn(updates)
    247 
    248         return self._dp

    [... skipping hidden 13 frame]

/usr/local/lib/python3.10/dist-packages/jax/_src/flatten_util.py in unravel_pytree(treedef, unravel_list, flat)
     51 
     52 def unravel_pytree(treedef, unravel_list, flat):
---> 53   return tree_unflatten(treedef, unravel_list(flat))
     54 
     55 def _ravel_list(lst):

    [... skipping hidden 1 frame]

/usr/local/lib/python3.10/dist-packages/jax/_src/flatten_util.py in _unravel_list_single_dtype(indices, shapes, arr)
     76 def _unravel_list_single_dtype(indices, shapes, arr):
     77   chunks = jnp.split(arr, indices[:-1])
---> 78   return [chunk.reshape(shape) for chunk, shape in zip(chunks, shapes)]
     79 
     80 def _unravel_list(indices, shapes, from_dtypes, to_dtype, arr):

/usr/local/lib/python3.10/dist-packages/jax/_src/flatten_util.py in <listcomp>(.0)
     76 def _unravel_list_single_dtype(indices, shapes, arr):
     77   chunks = jnp.split(arr, indices[:-1])
---> 78   return [chunk.reshape(shape) for chunk, shape in zip(chunks, shapes)]
     79 
     80 def _unravel_list(indices, shapes, from_dtypes, to_dtype, arr):

/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/array_methods.py in meth(self, *args, **kwargs)
    731 def _forward_method_to_aval(name):
    732   def meth(self, *args, **kwargs):
--> 733     return getattr(self.aval, name).fun(self, *args, **kwargs)
    734   return meth
    735 

/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/array_methods.py in _reshape(a, order, *args)
    141   newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
    142   if order == "C":
--> 143     return lax.reshape(a, newshape, None)
    144   elif order == "F":
    145     dims = list(range(a.ndim)[::-1])

    [... skipping hidden 8 frame]

/usr/local/lib/python3.10/dist-packages/jax/_src/lax/lax.py in _reshape_shape_rule(operand, new_sizes, dimensions)
   3376       not math.prod(np.shape(operand)) == math.prod(new_sizes)):
   3377     msg = 'reshape total size must be unchanged, got new_sizes {} for shape {}.'
-> 3378     raise TypeError(msg.format(new_sizes, np.shape(operand)))
   3379   if dimensions is not None:
   3380     if set(dimensions) != set(range(np.ndim(operand))):

TypeError: reshape total size must be unchanged, got new_sizes (2, 3, 7) for shape (84,).
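The sizes in the error are worth noting: 2 × 3 × 7 = 42, exactly half of 84. One plausible reading (an unconfirmed hypothesis, not a diagnosis of NetKet's code) is that with `jacobian_mode="complex"` the update for complex parameters is carried as a real vector holding real parts followed by imaginary parts, and that vector reaches the unravel function returned by `jax.flatten_util.ravel_pytree` without first being recombined into complex numbers. The sketch below reproduces the same kind of error with plain JAX; the single-leaf dict `params` and the vector `update_split` are hypothetical stand-ins, not NetKet objects.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical single-leaf complex parameter pytree, shaped like the error message
params = {"w": jnp.zeros((2, 3, 7), dtype=jnp.complex64)}

flat, unravel = ravel_pytree(params)
print(flat.shape, flat.dtype)  # (42,) complex64

# Assumed failure mode: an update stored as [real parts, imaginary parts],
# i.e. a real vector of length 2 * 42 = 84, is passed straight to `unravel`
update_split = jnp.concatenate([flat.real, flat.imag])
print(update_split.shape)  # (84,)

try:
    unravel(update_split)
    failed = False
except (TypeError, ValueError) as err:  # exact type and wording may vary across JAX versions
    failed = True
    print(type(err).__name__, err)

# Recombining the two halves into one complex vector restores the expected size
n = flat.size
update_complex = update_split[:n] + 1j * update_split[n:]
restored = unravel(update_complex)
print(restored["w"].shape, restored["w"].dtype)  # (2, 3, 7) complex64
```

In this picture, the "chunking" visible in the trace is the `jnp.split` inside JAX's `_unravel_list_single_dtype`, which cuts the flat vector back into one chunk per parameter leaf; it is unrelated to NetKet's sample chunking.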

From my (amateur) reading of the trace, it looks like something is going wrong with chunking? The model itself is rather complicated, so if this isn't a problem with VMC_SRt itself, I'll put together a minimal working example and edit my question. I'm just a little confused as to why it would work with the standard VMC driver and SR but not with VMC_SRt, since my understanding is that the only difference is in how the SR update is computed: VMC_SRt solves a linear system whose size is set by the number of samples rather than by the number of parameters.
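On the point that the two approaches differ only in which matrix is formed: both rest on the push-through identity `(O^H O + λ I)^{-1} O^H ε = O^H (O O^H + λ I)^{-1} ε`, where `O` has one row per sample and one column per parameter. The left side needs a parameters × parameters solve (SR-like), the right side a samples × samples solve (SRt-like). The NumPy check below uses random complex stand-ins for `O` and `ε`; the sizes and `diag_shift` value are arbitrary. It is a sketch of the identity only: NetKet's drivers also normalize by the number of samples and treat the non-holomorphic complex case in their own way, details this omits.

```python
import numpy as np

rng = np.random.default_rng(0)

# Arbitrary illustrative sizes: the sample-space form pays off when n_samples << n_params
n_samples, n_params = 8, 30
diag_shift = 0.01

# Random complex stand-ins for the centered log-derivative matrix O
# (one row per sample, one column per parameter) and the vector eps
O = rng.normal(size=(n_samples, n_params)) + 1j * rng.normal(size=(n_samples, n_params))
eps = rng.normal(size=n_samples) + 1j * rng.normal(size=n_samples)
O_H = O.conj().T

# Parameter-space form (SR-like): solve an n_params x n_params system
S = O_H @ O
dp_sr = np.linalg.solve(S + diag_shift * np.eye(n_params), O_H @ eps)

# Sample-space form (SRt-like): solve an n_samples x n_samples system, then map back
T = O @ O_H
dp_srt = O_H @ np.linalg.solve(T + diag_shift * np.eye(n_samples), eps)

print(np.allclose(dp_sr, dp_srt))  # True
```

Both forms give the same update vector of length `n_params`, so in exact arithmetic the choice affects cost, not the result; a shape mismatch like the one above would therefore point at how the update is packed or unpacked rather than at the linear algebra.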

Thank you in advance!