sisaman/LPGNN

randperm() received an invalid combination of arguments

Billy1900 opened this issue · 1 comment

When I run python train.py -d cora -m mbm, I get the following error:

Traceback (most recent call last):
  File "train.py", line 147, in 
    main()
  File "train.py", line 141, in main
    batch_train_and_test(args)
  File "train.py", line 90, in batch_train_and_test
    dataset = load_dataset(name=args.dataset, feature_range=(0, 1), sparse=True, device=args.device)
  File "/home/nqluo/experiement/lpgnn/datasets.py", line 91, in load_dataset
    dataset = _available_datasets[name](root=os.path.join(root, name))
  File "/home/nqluo/anaconda3/envs/torch-gpu/lib/python3.7/site-packages/torch_geometric/datasets/planetoid.py", line 55, in __init__
    super(Planetoid, self).__init__(root, transform, pre_transform)
  File "/home/nqluo/anaconda3/envs/torch-gpu/lib/python3.7/site-packages/torch_geometric/data/in_memory_dataset.py", line 54, in __init__
    pre_filter)
  File "/home/nqluo/anaconda3/envs/torch-gpu/lib/python3.7/site-packages/torch_geometric/data/dataset.py", line 92, in __init__
    self._process()
  File "/home/nqluo/anaconda3/envs/torch-gpu/lib/python3.7/site-packages/torch_geometric/data/dataset.py", line 165, in _process
    self.process()
  File "/home/nqluo/anaconda3/envs/torch-gpu/lib/python3.7/site-packages/torch_geometric/datasets/planetoid.py", line 109, in process
    data = data if self.pre_transform is None else self.pre_transform(data)
  File "/home/nqluo/experiement/lpgnn/transforms.py", line 76, in __call__
    perm = torch.randperm(num_nodes_with_class, generator=self.rng)
TypeError: randperm() received an invalid combination of arguments - got (int, generator=NoneType), but expected one of:
 * (int n, torch.Generator generator, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (int n, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)

I wonder if this is related to my torch version? My torch version is 1.4.0+cu100.

Yes, it is most likely due to your PyTorch version. You can remove "generator=self.rng" in transforms.py, line 76, to make it work (this option is not actually used in the code, it is only there for debugging), though I cannot guarantee the rest of the code will run without errors on your PyTorch version.
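
For reference, here is a minimal sketch of that change. The surrounding code in transforms.py is not shown in this issue, so the example value of num_nodes_with_class is purely illustrative; only the randperm call itself comes from the traceback above.

```python
import torch

# Minimal sketch of the suggested edit around transforms.py, line 76.
# On PyTorch 1.4, torch.randperm has no overload that accepts generator=None,
# which is what triggers the TypeError above; newer versions accept it.

num_nodes_with_class = 2708  # illustrative value only (roughly the size of Cora)

# Original call (fails on PyTorch 1.4 when self.rng is None):
# perm = torch.randperm(num_nodes_with_class, generator=self.rng)

# Patched call, with the generator keyword removed:
perm = torch.randperm(num_nodes_with_class)
print(perm[:5])
```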