设置seed似乎有bug
Closed this issue · 1 comments
illrayy commented
import numpy as np
import random
import torch
from torch.utils.data import Dataset, DataLoader
class RandomDataset(Dataset):
def __getitem__(self, index):
return np.random.randint(0, 1000, 3)
def __len__(self):
return 8
def seed_everything(seed=11):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def seed_everything_wrap(seed=11):
def _init_fn(worker_id):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
return _init_fn
dataset = RandomDataset()
seed = 23
seed_everything(seed)
dataloader = DataLoader(dataset, batch_size=2, num_workers=4, worker_init_fn=seed_everything_wrap(seed))
for epoch in range(3):
print(f"epoch: {epoch}")
for batch in dataloader:
print(batch)
输出是
epoch: 0
tensor([[595, 742, 40],
[969, 950, 488]])
tensor([[595, 742, 40],
[969, 950, 488]])
tensor([[595, 742, 40],
[969, 950, 488]])
tensor([[595, 742, 40],
[969, 950, 488]])
epoch: 1
tensor([[595, 742, 40],
[969, 950, 488]])
tensor([[595, 742, 40],
[969, 950, 488]])
tensor([[595, 742, 40],
[969, 950, 488]])
tensor([[595, 742, 40],
[969, 950, 488]])
epoch: 2
tensor([[595, 742, 40],
[969, 950, 488]])
tensor([[595, 742, 40],
[969, 950, 488]])
tensor([[595, 742, 40],
[969, 950, 488]])
tensor([[595, 742, 40],
[969, 950, 488]])
每个worker和epoch间的输出是一样的
illrayy commented
上述代码每次初始化worker的seed都是23,应该把worker_init_fn=seed_everything_wrap(seed)
改成worker_init_fn=seed_everything_wrap
输出
epoch: 0
tensor([[ 81, 169, 427],
[182, 596, 650]])
tensor([[977, 758, 359],
[110, 376, 906]])
tensor([[202, 234, 280],
[ 52, 717, 142]])
tensor([[337, 227, 759],
[236, 373, 282]])
epoch: 1
tensor([[582, 419, 678],
[522, 126, 356]])
tensor([[572, 10, 893],
[164, 870, 733]])
tensor([[107, 345, 285],
[702, 769, 716]])
tensor([[719, 632, 451],
[749, 765, 522]])
epoch: 2
tensor([[661, 964, 120],
[590, 768, 306]])
tensor([[405, 747, 811],
[331, 718, 927]])
tensor([[315, 521, 82],
[417, 380, 333]])
tensor([[856, 861, 761],
[422, 719, 197]])