guanjq/targetdiff

How to get batch.ligand_element_batch?


Hi, thank you for sharing such great work. However, I am a little confused about how to get batch.ligand_element_batch in the following training function:
def train(it):
    model.train()
    optimizer.zero_grad()
    for _ in range(config.train.n_acc_batch):
        batch = next(train_iterator).to(args.device)

        results = model.get_diffusion_loss(
            ligand_pos=batch.ligand_pos,
            ligand_v=batch.ligand_atom_feature_full,
            batch_ligand=batch.ligand_element_batch
        )
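My current guess, which I have not confirmed in the repository, is that this attribute is produced by PyTorch Geometric's `follow_batch` option. When a key such as `ligand_element` is listed there, PyG adds a `<key>_batch` vector that gives, for every ligand atom, the index of the molecule it belongs to in the mini-batch. The sketch below only illustrates what such a vector looks like. `make_batch_vector` is a hypothetical helper, not code from targetdiff or PyG.

```python
from typing import List


def make_batch_vector(num_atoms_per_graph: List[int]) -> List[int]:
    """Hypothetical helper: build a per-atom graph-index vector.

    Assumes the PyG `follow_batch` convention, where every atom of
    graph i is assigned the value i after the graphs are concatenated.
    """
    batch: List[int] = []
    for graph_idx, n_atoms in enumerate(num_atoms_per_graph):
        # Repeat the graph index once per atom in that graph
        batch.extend([graph_idx] * n_atoms)
    return batch


# Three ligands with 3, 2 and 4 atoms concatenated into one mini-batch
ligand_element_batch = make_batch_vector([3, 2, 4])
print(ligand_element_batch)  # [0, 0, 0, 1, 1, 2, 2, 2, 2]
```

If this guess is right, `batch_ligand` would let the model tell which atoms in the flattened `ligand_pos` tensor belong to the same molecule.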

Could you tell me where the ligand element batch is created during data processing?
Thank you