CRIPAC-DIG/GCA

Question about code

GeniusYx opened this issue · 0 comments

Sorry to bother you, but I am confused about why edge_weights is divided by edge_weights.mean() in the function below.

def drop_edge_weighted(edge_index, edge_weights, p: float, threshold: float = 1.):
    edge_weights = edge_weights / edge_weights.mean() * p
    edge_weights = edge_weights.where(edge_weights < threshold, torch.ones_like(edge_weights) * threshold)
    sel_mask = torch.bernoulli(1. - edge_weights).to(torch.bool)
    return edge_index[:, sel_mask]
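
To make the question concrete, here is a minimal sketch of what the normalization does numerically. It is one possible reading, not the authors' stated rationale: dividing by the mean rescales the weights to average 1, so after multiplying by p the average drop probability equals p (before the threshold clips anything), regardless of the original scale of the weights. The helper name drop_probabilities and the example weights are hypothetical and only used for illustration.

```python
import torch


def drop_probabilities(edge_weights: torch.Tensor, p: float, threshold: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: the first two lines of drop_edge_weighted, without the sampling."""
    # Rescale so the weights average to 1, then scale by the target drop rate p.
    probs = edge_weights / edge_weights.mean() * p
    # Cap every probability at `threshold`, as in the original snippet.
    return probs.where(probs < threshold, torch.ones_like(probs) * threshold)


# Made-up weights, chosen so that nothing is clipped at threshold = 1.0.
weights = torch.tensor([0.5, 1.0, 1.5, 2.0, 5.0])
probs = drop_probabilities(weights, p=0.3)

# The same weights on a 100x larger scale give the same probabilities,
# because the division by the mean cancels the scale.
probs_scaled = drop_probabilities(weights * 100.0, p=0.3)

print(probs)
print(probs.mean())
print(torch.allclose(probs, probs_scaled))
```

Under this reading, p acts as the expected fraction of edges removed, while individual edges are dropped more or less often in proportion to their weight. When clipping does occur, the average drop probability falls below p.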