divelab/DIG

Documentation Discrepancy: Incompatibility of DIG with PyTorch Geometric 2.3.0

Qianli-Wu opened this issue · 1 comment

Hello,

I noticed a discrepancy in the DIG documentation.

The documentation states that DIG requires PyTorch Geometric >= 2.0.0. However, DIG appears to be incompatible with PyTorch Geometric 2.3.0.

In the 2.3.0 release of PyTorch Geometric, the internal API of the MessagePassing class changed, as detailed in this commit. For example, the method __check_input__ was renamed to _check_input, which breaks lines in the DIG codebase that still call the old name, for instance:

def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
    size = self.__check_input__(edge_index, size)

    coll_dict = self.__collect__(self.__fused_user_args__, edge_index,
                                 size, kwargs)
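Because the renamed methods are private, one way for DIG to tolerate both releases would be to look the helper up by name at runtime, trying the new spelling first. A minimal sketch: OldStyleLayer, NewStyleLayer, resolve_private and check_input are hypothetical names, and the stand-in classes only mimic the two method names quoted above; they are not the real MessagePassing class.

```python
# Hypothetical stand-ins for the two naming schemes; NOT torch_geometric classes.
class OldStyleLayer:
    # Names ending in a double underscore are not name-mangled,
    # so this is stored and looked up as "__check_input__".
    def __check_input__(self, edge_index, size):
        return "old", size


class NewStyleLayer:
    def _check_input(self, edge_index, size):
        return "new", size


def resolve_private(obj, *names):
    """Return the first callable attribute of `obj` found among `names`."""
    for name in names:
        attr = getattr(obj, name, None)
        if callable(attr):
            return attr
    raise AttributeError(
        f"{type(obj).__name__} has none of the methods: {', '.join(names)}"
    )


def check_input(layer, edge_index, size=None):
    # Try the post-2.3.0 name first, then fall back to the pre-2.3.0 name.
    method = resolve_private(layer, "_check_input", "__check_input__")
    return method(edge_index, size)


print(check_input(OldStyleLayer(), edge_index=None))  # ('old', None)
print(check_input(NewStyleLayer(), edge_index=None))  # ('new', None)
```

Inside DIG's propagate, the equivalent call would be resolve_private(self, "_check_input", "__check_input__")(edge_index, size). This only covers a pure rename: if a release also changes a helper's signature or return value, the shim will not help, and pinning the PyTorch Geometric version or porting the override is likely the more durable fix.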

Also, the Dataset class was made abstract, so every subclass must implement its two abstract methods, len and get:

    @abstractmethod
    def len(self) -> int:
        r"""Returns the number of graphs stored in the dataset."""
        raise NotImplementedError

    @abstractmethod
    def get(self, idx: int) -> BaseData:
        r"""Gets the data object at index :obj:`idx`."""
        raise NotImplementedError
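The TypeError reported later in this thread follows from this change: Python refuses to instantiate a subclass of an abstract base class while any abstract method remains unimplemented, and defining __len__ and __getitem__ does not count as implementing len and get. A minimal, stdlib-only sketch; AbstractDataset, DunderOnlyDataset and can_instantiate are hypothetical names, not PyTorch Geometric or DIG code.

```python
from abc import ABC, abstractmethod


class AbstractDataset(ABC):
    # Hypothetical stand-in with the same two abstract methods as the
    # excerpt above; it is not torch_geometric.data.Dataset.
    @abstractmethod
    def len(self) -> int:
        raise NotImplementedError

    @abstractmethod
    def get(self, idx: int):
        raise NotImplementedError


class DunderOnlyDataset(AbstractDataset):
    # Implements only the Python protocol methods, as DIG's classes do.
    def __len__(self):
        return 3

    def __getitem__(self, idx):
        return idx


def can_instantiate(cls) -> bool:
    """Return True if cls() succeeds, False if it raises TypeError."""
    try:
        cls()
    except TypeError:
        return False
    return True


# The abstract names are inherited unchanged by the subclass.
print(sorted(DunderOnlyDataset.__abstractmethods__))  # ['get', 'len']
print(can_instantiate(DunderOnlyDataset))  # False
```

The exact wording of the TypeError differs between Python versions, which is why the helper checks only the exception type.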

Consequently, DIG classes such as MarginalSubgraphDataset, which define only __len__ and __getitem__, can no longer be instantiated until they also implement len and get:

import torch
from torch_geometric.data import Data, Dataset


class MarginalSubgraphDataset(Dataset):
    def __init__(self, data, exclude_mask, include_mask, subgraph_build_func):
        self.num_nodes = data.num_nodes
        self.X = data.x
        self.edge_index = data.edge_index
        self.device = self.X.device
        self.label = data.y
        self.exclude_mask = torch.tensor(exclude_mask).type(torch.float32).to(self.device)
        self.include_mask = torch.tensor(include_mask).type(torch.float32).to(self.device)
        self.subgraph_build_func = subgraph_build_func

    def __len__(self):
        return self.exclude_mask.shape[0]

    def __getitem__(self, idx):
        exclude_graph_X, exclude_graph_edge_index = self.subgraph_build_func(self.X, self.edge_index, self.exclude_mask[idx])
        include_graph_X, include_graph_edge_index = self.subgraph_build_func(self.X, self.edge_index, self.include_mask[idx])
        exclude_data = Data(x=exclude_graph_X, edge_index=exclude_graph_edge_index)
        include_data = Data(x=include_graph_X, edge_index=include_graph_edge_index)
        return exclude_data, include_data
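The smallest fix is probably to keep the existing __len__ and __getitem__ and add len and get methods that delegate to them. A minimal sketch of that pattern, written so it runs without torch or PyTorch Geometric: AbstractDataset is a hypothetical stand-in for the abstract interface quoted above, and PatchedSubgraphDataset is a toy analogue of MarginalSubgraphDataset, not DIG's actual class.

```python
from abc import ABC, abstractmethod


class AbstractDataset(ABC):
    # Hypothetical stand-in mirroring only the two abstract methods of
    # torch_geometric.data.Dataset; it is not the real class.
    @abstractmethod
    def len(self) -> int:
        raise NotImplementedError

    @abstractmethod
    def get(self, idx: int):
        raise NotImplementedError


class PatchedSubgraphDataset(AbstractDataset):
    # Toy analogue of MarginalSubgraphDataset: masks are plain lists and
    # `build` is a hypothetical stand-in for subgraph_build_func.
    def __init__(self, exclude_mask, include_mask, build):
        self.exclude_mask = exclude_mask
        self.include_mask = include_mask
        self.build = build

    # Unchanged protocol methods, as in the DIG class.
    def __len__(self):
        # Inside a method body, `len` is still the builtin, not the method below.
        return len(self.exclude_mask)

    def __getitem__(self, idx):
        return self.build(self.exclude_mask[idx]), self.build(self.include_mask[idx])

    # The fix: concrete len/get that delegate to the existing methods,
    # which satisfies the abstract interface without changing behavior.
    def len(self) -> int:
        return self.__len__()

    def get(self, idx: int):
        return self.__getitem__(idx)


ds = PatchedSubgraphDataset([[0, 1], [1, 1]], [[1, 1], [1, 0]], build=sum)
print(len(ds))    # 2
print(ds[0])      # (1, 2)
print(ds.get(1))  # (2, 1)
```

In DIG this would amount to adding the same two delegating methods to MarginalSubgraphDataset. The delegation is safe only because the subclass overrides __getitem__: PyTorch Geometric's own __getitem__ calls get, so a get that delegated to an inherited __getitem__ would recurse indefinitely.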

This discrepancy may confuse users who try to install and use DIG with PyTorch Geometric 2.3.0.
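If, as this report suggests, the breakage is limited to releases from 2.3.0 onward, the documented requirement could read torch_geometric>=2.0.0,<2.3.0 until the code is updated. That upper bound is an assumption; this thread does not confirm that 2.1 and 2.2 work. A sketch of a matching runtime guard follows; parse_version and is_supported are hypothetical helpers, and in practice the input would be torch_geometric.__version__.

```python
import re

# Assumed supported range, inferred from this report only:
# >=2.0.0 (documented minimum) and <2.3.0 (first release reported broken).
MIN_SUPPORTED = (2, 0, 0)
FIRST_BROKEN = (2, 3, 0)


def parse_version(version: str) -> tuple:
    """Parse the leading MAJOR.MINOR[.PATCH] of a version string.

    Anything after it ('+cu118', 'rc1', ...) is ignored, so this is a
    coarse check, not a full PEP 440 parser; e.g. '2.3.0rc1' is treated
    as 2.3.0.
    """
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor, patch = match.groups(default="0")
    return int(major), int(minor), int(patch)


def is_supported(version: str) -> bool:
    """True if the version falls inside the assumed [2.0.0, 2.3.0) range."""
    return MIN_SUPPORTED <= parse_version(version) < FIRST_BROKEN


for v in ["2.0.0", "2.2.0", "2.3.0", "2.4.0+cu118"]:
    print(v, is_supported(v))
# 2.0.0 True
# 2.2.0 True
# 2.3.0 False
# 2.4.0+cu118 False
```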

Hi, I'm using PyTorch Geometric 2.4.0 and I'm still facing the same problem described above. I'm currently trying to run the SubgraphX tutorial. Visualizing the SubgraphX results does not work, and I get:
"TypeError: Can't instantiate abstract class MarginalSubgraphDataset with abstract methods get, len"
Could there be another problem?