ISO17 data load bug
Closed this issue · 2 comments
C-K-Loan commented
Looks like there is a bug when loading ISO17 data into torch tensors:
https://colab.research.google.com/drive/1LkySYJWJoFScx1w4DKD9SJ5XOXtsmYU1?usp=sharing
! pip install schnetpack
from schnetpack.data import ASEAtomsData
from schnetpack.datasets import ISO17
from schnetpack.transform import ASENeighborList
# Build the ISO17 datamodule from a local ASE database path.
# NOTE(review): indentation of this snippet was flattened by the paste.
iso17data = ISO17(
'./iso17.db',
fold = 'reference_eq',
batch_size=10,
num_train=11000,
num_val=10000,
transforms=[ASENeighborList(cutoff=5.)]
)
# Download/verify the data, then create the train/val/test splits.
iso17data.prepare_data()
iso17data.setup()
# Split sizes and available properties print fine ...
print('Number of reference calculations:', len(iso17data.dataset))
print('Number of train data:', len(iso17data.train_dataset))
print('Number of validation data:', len(iso17data.val_dataset))
print('Number of test data:', len(iso17data.test_dataset))
print('Available properties:')
for p in iso17data.dataset.available_properties:
print('-', p)
# ... but indexing a single sample raises KeyError('total_energy'),
# as shown in the traceback below.
iso17data.dataset[0]
throws
KeyError Traceback (most recent call last)
[<ipython-input-1-2b61099d20f0>](https://localhost:8080/#) in <cell line: 26>()
24 print('-', p)
25
---> 26 iso17data.dataset[0]
1 frames
[/usr/local/lib/python3.10/dist-packages/schnetpack/data/atoms.py](https://localhost:8080/#) in _get_properties(self, conn, idx, load_properties, load_structure)
345 for pname in load_properties:
346 properties[pname] = (
--> 347 torch.tensor(row.data[pname].copy()) * self.conversions[pname]
348 )
349
KeyError: 'total_energy'
stefaanhessmann commented
Hi @C-K-Loan ,
You are right — the database was not formatted correctly. I have opened a pull request to fix this.
For now, this code should fix your databases:
import os
import shutil
import tempfile

import numpy as np
from ase.db import connect
from tqdm import tqdm

from schnetpack.datasets import ISO17

data_path = "iso17.db"

# Fix the databases: the broken download stores the energy only as a scalar
# row attribute (row.total_energy) instead of inside row.data, so
# schnetpack's property loader raises KeyError('total_energy').  Rewrite
# each fold's db with energy and forces placed in the data dict, plus the
# metadata schnetpack expects (property units, distance unit, atomrefs).
tmpdir = tempfile.mkdtemp("iso17")
for fold in ISO17.existing_folds:
    dbpath = os.path.join(data_path, "iso17", fold + ".db")
    tmp_dbpath = os.path.join(tmpdir, "tmp.db")
    with connect(dbpath) as conn:
        with connect(tmp_dbpath) as tmp_conn:
            tmp_conn.metadata = {
                "_property_unit_dict": {ISO17.energy: "eV", ISO17.forces: "eV/Ang"},
                "_distance_unit": "Ang",
                "atomrefs": {},
            }
            # ASE db rows are 1-indexed; copy every row, moving the energy
            # into the data dict and casting the forces to an ndarray.
            for idx in tqdm(range(len(conn)), f"parsing database file {dbpath}"):
                atmsrw = conn.get(idx + 1)
                data = atmsrw.data
                data[ISO17.forces] = np.array(data[ISO17.forces])
                data[ISO17.energy] = np.array([atmsrw.total_energy])
                tmp_conn.write(atmsrw.toatoms(), data=data)
    # Replace the broken db with the rewritten one.  shutil.move (not
    # os.rename) so this also works when the temp dir lives on a different
    # filesystem than the data path (os.rename raises OSError/EXDEV there,
    # and /tmp is frequently a separate tmpfs).
    os.remove(dbpath)
    shutil.move(tmp_dbpath, dbpath)
shutil.rmtree(tmpdir)
# Check the rewritten database: rebuild the datamodule and confirm the
# splits and properties load without a KeyError.
from schnetpack.transform import ASENeighborList

iso17data = ISO17(
    "./iso17.db",
    fold="reference_eq",
    batch_size=10,
    num_train=11000,
    num_val=10000,
    transforms=[ASENeighborList(cutoff=5.0)],
)
iso17data.prepare_data()
iso17data.setup()

# Report the size of each split.
for label, split in (
    ("Number of reference calculations:", iso17data.dataset),
    ("Number of train data:", iso17data.train_dataset),
    ("Number of validation data:", iso17data.val_dataset),
    ("Number of test data:", iso17data.test_dataset),
):
    print(label, len(split))
print("Available properties:")
for prop in iso17data.dataset.available_properties:
    print("-", prop)
Let me know if this solves the issue!
C-K-Loan commented
This fixes the issue — thank you, @Stefaanhess!