How to get `batch.ligand_element_batch`?
545487677 commented
Hi, thank you for sharing such great work. However, I am a little confused about how I can get `batch.ligand_element_batch` in the `train` function:
```python
def train(it):
    model.train()
    optimizer.zero_grad()
    for _ in range(config.train.n_acc_batch):
        batch = next(train_iterator).to(args.device)
        results = model.get_diffusion_loss(
            ligand_pos=batch.ligand_pos,
            ligand_v=batch.ligand_atom_feature_full,
            # Per-atom index of the complex each ligand atom belongs to
            batch_ligand=batch.ligand_element_batch
        )
```
Could you tell me where `ligand_element_batch` is created, i.e. which processing step produces it?
Thank you
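For anyone landing here: the `_batch` suffix matches the naming PyTorch Geometric uses for `follow_batch`. When a `DataLoader` from `torch_geometric.loader` is built with `follow_batch=['ligand_element']`, each collated batch gains a `ligand_element_batch` vector that maps every ligand atom to the index of the complex it came from. I am assuming (this thread does not confirm it) that the training script lists `ligand_element` in the `follow_batch` argument when it builds the loader, so the place to look would be that `DataLoader` call rather than a custom processing step. The plain-Python sketch below only mimics the indexing to show what the vector contains; `make_follow_batch` is a hypothetical name, not part of PyG or this repo.

```python
from typing import List, Sequence

# In PyG the real mechanism is (roughly):
#   DataLoader(dataset, batch_size=..., follow_batch=['ligand_element'])
# This sketch does NOT call PyG; it only reproduces the resulting index vector.


def make_follow_batch(ligand_elements: Sequence[Sequence[int]]) -> List[int]:
    """Hypothetical helper: build the per-atom graph index that PyG's
    follow_batch would store as `ligand_element_batch`.

    ligand_elements[i] holds the element list of the i-th ligand in the batch.
    """
    batch_index: List[int] = []
    for graph_idx, elements in enumerate(ligand_elements):
        # Every atom of ligand `graph_idx` gets that ligand's index.
        batch_index.extend([graph_idx] * len(elements))
    return batch_index


# Two ligands given as atomic numbers: (C, C, O) and (N, C).
ligands = [[6, 6, 8], [7, 6]]
# Collation concatenates per-atom attributes across the batch.
ligand_element = [z for elements in ligands for z in elements]
ligand_element_batch = make_follow_batch(ligands)
print(ligand_element)        # [6, 6, 8, 7, 6]
print(ligand_element_batch)  # [0, 0, 0, 1, 1]
```

Because the atoms of all ligands are concatenated into one flat tensor, a vector like this is what lets the model pool or normalize per complex (e.g. with scatter operations) while still processing the whole batch at once.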