atomistic-machine-learning/schnetpack

ISO17 data load bug

Closed this issue · 2 comments

Looks like there is a bug when loading the ISO17 data into torch tensors:
https://colab.research.google.com/drive/1LkySYJWJoFScx1w4DKD9SJ5XOXtsmYU1?usp=sharing

! pip install schnetpack

from schnetpack.data import ASEAtomsData
from schnetpack.datasets import ISO17
from schnetpack.transform import ASENeighborList

# Reproduce: load the ISO17 reference_eq fold, then index the dataset.
iso17data = ISO17(
    "./iso17.db",
    fold="reference_eq",
    batch_size=10,
    num_train=11000,
    num_val=10000,
    transforms=[ASENeighborList(cutoff=5.0)],
)
iso17data.prepare_data()
iso17data.setup()

# Report the split sizes.
for label, ds in (
    ("Number of reference calculations:", iso17data.dataset),
    ("Number of train data:", iso17data.train_dataset),
    ("Number of validation data:", iso17data.val_dataset),
    ("Number of test data:", iso17data.test_dataset),
):
    print(label, len(ds))
print('Available properties:')

for prop in iso17data.dataset.available_properties:
    print('-', prop)

iso17data.dataset[0]  # <- raises KeyError: 'total_energy'

throws

KeyError                                  Traceback (most recent call last)
[<ipython-input-1-2b61099d20f0>](https://localhost:8080/#) in <cell line: 26>()
     24     print('-', p)
     25 
---> 26 iso17data.dataset[0]

1 frames
[/usr/local/lib/python3.10/dist-packages/schnetpack/data/atoms.py](https://localhost:8080/#) in _get_properties(self, conn, idx, load_properties, load_structure)
    345         for pname in load_properties:
    346             properties[pname] = (
--> 347                 torch.tensor(row.data[pname].copy()) * self.conversions[pname]
    348             )
    349 

KeyError: 'total_energy'

Hi @C-K-Loan ,
you are right. The database was not formatted correctly. I added a pull request to fix this.

For now, this code should fix your databases:

import os
import shutil
import tempfile
from tqdm import tqdm
import numpy as np
from ase.db import connect
from schnetpack.datasets import ISO17


data_path = "iso17.db"

# Repair each fold's database in place: schnetpack expects the energy stored
# in the row's data dict (under ISO17.energy) together with unit metadata,
# while the downloaded databases keep it only as the ASE `total_energy`
# column — hence the KeyError above.
#
# The scratch directory is created NEXT TO the databases (not in /tmp) so the
# final os.replace() stays on one filesystem and is atomic; the context
# manager also guarantees cleanup even if a fold fails partway through
# (the original mkdtemp + rmtree leaked the directory on error, and passed
# "iso17" as a *suffix* where a prefix was intended).
with tempfile.TemporaryDirectory(
    prefix="iso17", dir=os.path.join(data_path, "iso17")
) as tmpdir:
    for fold in ISO17.existing_folds:
        dbpath = os.path.join(data_path, "iso17", fold + ".db")
        tmp_dbpath = os.path.join(tmpdir, "tmp.db")
        with connect(dbpath) as conn, connect(tmp_dbpath) as tmp_conn:
            tmp_conn.metadata = {
                "_property_unit_dict": {ISO17.energy: "eV", ISO17.forces: "eV/Ang"},
                "_distance_unit": "Ang",
                "atomrefs": {},
            }
            # Copy every row, promoting the energy into the data dict.
            for idx in tqdm(range(len(conn)), f"parsing database file {dbpath}"):
                atmsrw = conn.get(idx + 1)  # ASE db rows are 1-indexed
                data = atmsrw.data
                data[ISO17.forces] = np.array(data[ISO17.forces])
                data[ISO17.energy] = np.array([atmsrw.total_energy])
                tmp_conn.write(atmsrw.toatoms(), data=data)

        # Atomically swap the fixed database into place: unlike
        # os.remove() + os.rename() there is no window where the
        # database file is missing, and an existing target is fine.
        os.replace(tmp_dbpath, dbpath)

# check database
# Verify the repaired databases load without the KeyError.
from schnetpack.transform import ASENeighborList

iso17data = ISO17(
    "./iso17.db",
    fold="reference_eq",
    batch_size=10,
    num_train=11000,
    num_val=10000,
    transforms=[ASENeighborList(cutoff=5.0)],
)
iso17data.prepare_data()
iso17data.setup()

# Report the split sizes.
for label, ds in (
    ("Number of reference calculations:", iso17data.dataset),
    ("Number of train data:", iso17data.train_dataset),
    ("Number of validation data:", iso17data.val_dataset),
    ("Number of test data:", iso17data.test_dataset),
):
    print(label, len(ds))
print('Available properties:')

for prop in iso17data.dataset.available_properties:
    print('-', prop)

    

Let me know if this solves the issue!

This fixes the issue thank you @Stefaanhess