deepmodeling/CrystalFormer

IndexError when indexing `atom_mask` in `sample_crystal` function

SchrodingersCattt opened this issue · 1 comments

Issue Description

When running the generate_and_visualize function with the default settings in both the Bohrium notebook and Google Colab, as in the following code block, an IndexError is raised inside the _index_to_gather function in jax._src.numpy.lax_numpy.

# ============= params to control the generation =============
spacegroup = 225  
elements = "Si O"   
temperature = 1.0 
seed = 42

# =============== generate and visualization =================
generate_and_visualize(spacegroup, elements, temperature, seed)

The complete error output is:

Generating with spacegroup=225, elements=Si O, temperature=1.0
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[6], line 8
      5 seed = 42          # random seed
      7 # =============== generate and visualization =================
----> 8 generate_and_visualize(spacegroup, elements, temperature, seed)

Cell In[4], line 20, in generate_and_visualize(spacegroup, elements, temperature, seed)
     18 key, subkey = jax.random.split(key)
     19 start_time = time()
---> 20 XYZ, A, W, M, L = sample_crystal(subkey, transformer, params, args.n_max, n_sample, args.atom_types, args.wyck_types, args.Kx, args.Kl, spacegroup, None, atom_mask, top_p, temperature, temperature, args.use_foriloop)
     21 end_time = time()
     22 print("executation time:", end_time - start_time)

    [... skipping hidden 11 frame]

File /CrystalFormer/./src/sample.py:147, in sample_crystal(key, transformer, params, n_max, batchsize, atom_types, wyck_types, Kx, Kl, g, w_mask, atom_mask, top_p, temperature, T1, constraints)
    144 Z = jnp.zeros((batchsize, n_max))
    145 L = jnp.zeros((batchsize, n_max, Kl+2*6*Kl)) # we accumulate lattice params and sample lattice after
--> 147 key, W, A, X, Y, Z, L = jax.lax.fori_loop(0, n_max, body_fn, (key, W, A, X, Y, Z, L))
    149 M = mult_table[g-1, W]
    150 num_sites = jnp.sum(A!=0, axis=1)

    [... skipping hidden 12 frame]

File /CrystalFormer/./src/sample.py:83, in sample_crystal.<locals>.body_fn(i, state)
     80 a_logit = h_al[:, :atom_types]
     82 key, subkey = jax.random.split(key)
---> 83 a_logit = a_logit + jnp.where(atom_mask[i, :], 1e10, 0.0) # enhance the probability of masked atoms (do not need to normalize since we only use it for sampling, not computing logp)
     84 _temp = jax.lax.cond(i==0,
     85                         true_fun=lambda x: jnp.array(T1, dtype=float),
     86                         false_fun=lambda x: temperature,
     87                         operand=None)
     88 _a = sample_top_p(subkey, a_logit, top_p, _temp)  # use T1 for the first atom type

File /opt/mamba/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:739, in _forward_operator_to_aval.<locals>.op(self, *args)
    738 def op(self, *args):
--> 739   return getattr(self.aval, f"_{name}")(self, *args)

File /opt/mamba/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:352, in _getitem(self, item)
    351 def _getitem(self, item):
--> 352   return lax_numpy._rewriting_take(self, item)

File /opt/mamba/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:5616, in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
   5613       return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
   5615 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
-> 5616 return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
   5617                unique_indices, mode, fill_value)

File /opt/mamba/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:5625, in _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value)
   5622 def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
   5623             unique_indices, mode, fill_value):
   5624   idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
-> 5625   indexer = _index_to_gather(shape(arr), idx)  # shared with _scatter_update
   5626   y = arr
   5628   if fill_value is not None:

File /opt/mamba/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:5733, in _index_to_gather(x_shape, idx, normalize_indices)
   5730 def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
   5731                      normalize_indices: bool = True) -> _Indexer:
   5732   # Remove ellipses and add trailing slice(None)s.
-> 5733   idx = _canonicalize_tuple_index(len(x_shape), idx)
   5735   # Check for scalar boolean indexing: this requires inserting extra dimensions
   5736   # before performing the rest of the logic.
   5737   scalar_bool_dims: Sequence[int] = [n for n, i in enumerate(idx) if isinstance(i, bool)]

File /opt/mamba/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:6053, in _canonicalize_tuple_index(arr_ndim, idx, array_name)
   6051 num_dimensions_consumed = sum(not (e is None or e is Ellipsis or isinstance(e, bool)) for e in idx)
   6052 if num_dimensions_consumed > arr_ndim:
-> 6053   raise IndexError(
   6054       f"Too many indices for {array_name}: {num_dimensions_consumed} "
   6055       f"non-None/Ellipsis indices for dim {arr_ndim}.")
   6056 ellipses = (i for i, elt in enumerate(idx) if elt is Ellipsis)
   6057 ellipsis_index = next(ellipses, None)

IndexError: Too many indices for array: 2 non-None/Ellipsis indices for dim 1.

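The error message says that an array with 1 dimension was indexed with 2 indices. This points to the expression atom_mask[i, :] inside the body_fn function of sample_crystal (src/sample.py, line 83): the notebook passes a 1D atom_mask, while body_fn indexes it as a 2D array.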
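The mismatch can be reproduced outside CrystalFormer with a minimal sketch. The helper `accumulate_mask_bias` and the small sizes below are hypothetical and only mimic the `atom_mask[i, :]` lookup inside a `jax.lax.fori_loop`; they are not the repository's code.

```python
import jax
import jax.numpy as jnp

n_max, atom_types = 4, 5  # hypothetical small sizes, for illustration only


def accumulate_mask_bias(atom_mask):
    """Mimic the per-step lookup atom_mask[i, :] from body_fn (illustrative only)."""
    def body_fn(i, acc):
        # Same indexing pattern as line 83 of src/sample.py in the traceback.
        return acc + jnp.where(atom_mask[i, :], 1e10, 0.0)

    return jax.lax.fori_loop(0, n_max, body_fn, jnp.zeros(atom_types))


# A 1D mask of shape (atom_types,) cannot be indexed with two indices.
mask_1d = jnp.zeros((atom_types,), dtype=bool)
try:
    accumulate_mask_bias(mask_1d)
except IndexError as err:
    print("IndexError:", err)

# A 2D mask of shape (n_max, atom_types) works with the same indexing.
mask_2d = jnp.zeros((n_max, atom_types), dtype=bool)
print(accumulate_mask_bias(mask_2d).shape)
```

The exact wording of the IndexError message may differ between JAX versions.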

Could you please review the suggested fix? Let me know if you need any further information or have any additional questions. Thanks a lot for your time!

Sorry for the late reply. The error is caused by a recent update to our code. We want to control which elements are allowed at each sampling step, so we extended the shape of atom_mask from (atom_types,) to (n_max, atom_types). With this change, users can specify a different atom mask for each step. However, the change conflicted with the earlier notebook code, which was written for the old 1D shape.
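For anyone who still has a 1D mask from older code, one possible way to adapt it to the new (n_max, atom_types) shape is to repeat it for every sampling step. This is a minimal sketch under assumptions: `expand_atom_mask` is a hypothetical helper that is not part of CrystalFormer, and the sizes and the selected element indices are made up for illustration.

```python
import jax.numpy as jnp


def expand_atom_mask(atom_mask, n_max):
    """Return a mask of shape (n_max, atom_types).

    Hypothetical helper: a 1D mask is repeated for every sampling step,
    while a mask that is already 2D is returned unchanged.
    """
    atom_mask = jnp.asarray(atom_mask)
    if atom_mask.ndim == 1:
        return jnp.broadcast_to(atom_mask, (n_max, atom_mask.shape[0]))
    return atom_mask


n_max, atom_types = 4, 5  # hypothetical small sizes, for illustration only

# Old-style 1D mask with two arbitrary entries switched on.
mask_1d = jnp.zeros((atom_types,), dtype=bool).at[jnp.array([1, 3])].set(True)

# New-style 2D mask: the same row for every step, so mask_2d[i, :] is valid.
mask_2d = expand_atom_mask(mask_1d, n_max)
print(mask_2d.shape)
```

Each row of the 2D mask can then be edited independently if a different set of elements is wanted at a particular step.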

I have fixed it in the Bohrium notebook and Google Colab. Feel free to contact me if you have any questions about CrystalFormer.