IndexError when indexing `atom_mask` in `sample_crystal` function
SchrodingersCattt opened this issue · 1 comments
Issue Description
When running the `generate_and_visualize` function with default settings in the Bohrium notebook and in Google Colab, as in the following code block, an `IndexError` is raised from the `_index_to_gather` function in `jax._src.numpy.lax_numpy`.
# ============= params to control the generation =============
spacegroup = 225
elements = "Si O"
temperature = 1.0
seed = 42
# =============== generate and visualization =================
generate_and_visualize(spacegroup, elements, temperature, seed)
The complete error output is:
Generating with spacegroup=225, elements=Si O, temperature=1.0
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[6], line 8
5 seed = 42 # random seed
7 # =============== generate and visualization =================
----> 8 generate_and_visualize(spacegroup, elements, temperature, seed)
Cell In[4], line 20, in generate_and_visualize(spacegroup, elements, temperature, seed)
18 key, subkey = jax.random.split(key)
19 start_time = time()
---> 20 XYZ, A, W, M, L = sample_crystal(subkey, transformer, params, args.n_max, n_sample, args.atom_types, args.wyck_types, args.Kx, args.Kl, spacegroup, None, atom_mask, top_p, temperature, temperature, args.use_foriloop)
21 end_time = time()
22 print("executation time:", end_time - start_time)
[... skipping hidden 11 frame]
File /CrystalFormer/./src/sample.py:147, in sample_crystal(key, transformer, params, n_max, batchsize, atom_types, wyck_types, Kx, Kl, g, w_mask, atom_mask, top_p, temperature, T1, constraints)
144 Z = jnp.zeros((batchsize, n_max))
145 L = jnp.zeros((batchsize, n_max, Kl+2*6*Kl)) # we accumulate lattice params and sample lattice after
--> 147 key, W, A, X, Y, Z, L = jax.lax.fori_loop(0, n_max, body_fn, (key, W, A, X, Y, Z, L))
149 M = mult_table[g-1, W]
150 num_sites = jnp.sum(A!=0, axis=1)
[... skipping hidden 12 frame]
File /CrystalFormer/./src/sample.py:83, in sample_crystal.<locals>.body_fn(i, state)
80 a_logit = h_al[:, :atom_types]
82 key, subkey = jax.random.split(key)
---> 83 a_logit = a_logit + jnp.where(atom_mask[i, :], 1e10, 0.0) # enhance the probability of masked atoms (do not need to normalize since we only use it for sampling, not computing logp)
84 _temp = jax.lax.cond(i==0,
85 true_fun=lambda x: jnp.array(T1, dtype=float),
86 false_fun=lambda x: temperature,
87 operand=None)
88 _a = sample_top_p(subkey, a_logit, top_p, _temp) # use T1 for the first atom type
File /opt/mamba/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:739, in _forward_operator_to_aval.<locals>.op(self, *args)
738 def op(self, *args):
--> 739 return getattr(self.aval, f"_{name}")(self, *args)
File /opt/mamba/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:352, in _getitem(self, item)
351 def _getitem(self, item):
--> 352 return lax_numpy._rewriting_take(self, item)
File /opt/mamba/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:5616, in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
5613 return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
5615 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
-> 5616 return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
5617 unique_indices, mode, fill_value)
File /opt/mamba/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:5625, in _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value)
5622 def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
5623 unique_indices, mode, fill_value):
5624 idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
-> 5625 indexer = _index_to_gather(shape(arr), idx) # shared with _scatter_update
5626 y = arr
5628 if fill_value is not None:
File /opt/mamba/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:5733, in _index_to_gather(x_shape, idx, normalize_indices)
5730 def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
5731 normalize_indices: bool = True) -> _Indexer:
5732 # Remove ellipses and add trailing slice(None)s.
-> 5733 idx = _canonicalize_tuple_index(len(x_shape), idx)
5735 # Check for scalar boolean indexing: this requires inserting extra dimensions
5736 # before performing the rest of the logic.
5737 scalar_bool_dims: Sequence[int] = [n for n, i in enumerate(idx) if isinstance(i, bool)]
File /opt/mamba/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:6053, in _canonicalize_tuple_index(arr_ndim, idx, array_name)
6051 num_dimensions_consumed = sum(not (e is None or e is Ellipsis or isinstance(e, bool)) for e in idx)
6052 if num_dimensions_consumed > arr_ndim:
-> 6053 raise IndexError(
6054 f"Too many indices for {array_name}: {num_dimensions_consumed} "
6055 f"non-None/Ellipsis indices for dim {arr_ndim}.")
6056 ellipses = (i for i, elt in enumerate(idx) if elt is Ellipsis)
6057 ellipsis_index = next(ellipses, None)
IndexError: Too many indices for array: 2 non-None/Ellipsis indices for dim 1.
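It seems that the error comes from the indexing operation on the `atom_mask` tensor inside the `body_fn` function of `sample_crystal`: the expression `atom_mask[i, :]` uses two indices, but, according to the error message, the array passed in has only one dimension.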
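The mismatch can be reproduced outside of CrystalFormer. The following is a minimal sketch, not code from the repository: the values of `n_max` and `atom_types` and the helper name `index_step` are assumptions chosen for illustration, and the exact wording of the error message may vary between JAX versions.

```python
import jax.numpy as jnp

# Illustrative sizes only (assumptions, not CrystalFormer defaults).
n_max = 4
atom_types = 6

def index_step(mask, i):
    # Mirrors the expression `atom_mask[i, :]` from the traceback.
    return mask[i, :]

# A 1-D mask of shape (atom_types,): indexing with two indices fails.
atom_mask_1d = jnp.zeros((atom_types,), dtype=bool)
try:
    index_step(atom_mask_1d, 0)
    error_message = None
except IndexError as err:
    error_message = str(err)

# A 2-D mask of shape (n_max, atom_types): the same expression works
# and returns the mask row for step i.
atom_mask_2d = jnp.zeros((n_max, atom_types), dtype=bool)
row = index_step(atom_mask_2d, 0)

print("1-D mask error:", error_message)
print("2-D mask row shape:", row.shape)
```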
Would you please review the suggested fix? And let me know if you need any further assistance or have any additional questions. Thanks a lot for your time!
Sorry for the late reply. The error is caused by a recent update to our code. We want to control which elements are allowed at each step of sampling, so we extended the shape of `atom_mask` from `(atom_types,)` to `(n_max, atom_types)`. This lets users specify a different set of masked atoms at each step. However, the change conflicted with the previous code.
I have fixed it in the Bohrium notebook and Google Colab. Feel free to contact me if you have any questions about CrystalFormer.
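For anyone still holding a mask in the old `(atom_types,)` layout, one way to adapt it to the `(n_max, atom_types)` layout is to repeat the same row for every sampling step. This is a sketch under assumptions, not the fix applied in the notebooks: `expand_atom_mask` is a hypothetical helper that does not exist in CrystalFormer, and the sizes below are illustrative.

```python
import jax.numpy as jnp

def expand_atom_mask(atom_mask, n_max):
    """Hypothetical helper: return a mask of shape (n_max, atom_types).

    A 1-D mask is repeated for every sampling step; a 2-D mask that
    already has n_max rows is returned unchanged.
    """
    atom_mask = jnp.asarray(atom_mask)
    if atom_mask.ndim == 1:
        # Broadcast the single row across all n_max steps.
        return jnp.broadcast_to(atom_mask, (n_max, atom_mask.shape[0]))
    if atom_mask.ndim == 2 and atom_mask.shape[0] == n_max:
        return atom_mask
    raise ValueError(
        f"expected shape (atom_types,) or ({n_max}, atom_types), "
        f"got {atom_mask.shape}"
    )

# Illustrative example: 5 atom types, indices 1 and 3 flagged in the mask.
old_mask = jnp.array([False, True, False, True, False])
expanded = expand_atom_mask(old_mask, n_max=3)

print(expanded.shape)
print(expanded[2])  # row used at step i = 2, identical to old_mask
```

With a mask in this layout, an expression such as `atom_mask[i, :]` selects the row for step `i`, and individual rows can then be edited to allow different elements at different steps.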