TorchRNG makes PyTorch's pseudo-random operations deterministic with only a few extra lines of code. The feature is inspired by the JAX library.
TorchRNG depends only on PyTorch. You can install it via pip as follows:
$ pip install git+https://github.com/h-terao/torchrng
Take the following steps to use TorchRNG.
- Create a PRNG key using torchrng.PRNGKey.
- Transform a stochastic callable (i.e., a function or class method) using torchrng.deterministic. The transformed function takes a new argument, key, as its first positional argument (see the example below). If you pass the same PRNGKey as key, the transformed function returns the same results.
- If you need multiple PRNGKeys, split a PRNGKey using torchrng.split.
The following is a short demo of TorchRNG.
import torch
import torchrng
def random_fun(x: torch.Tensor) -> torch.Tensor:
    # Example of a stochastic function.
    return x * torch.rand_like(x)
tensor = torch.ones(size=(10, 10))
# random_fun is stochastic, so the two results differ.
x = random_fun(tensor)
y = random_fun(tensor)
print("Expected: False, Actual:", (x == y).all())
# Example of Step 1. If you have already called `torch.manual_seed`, `seed=None` also works.
key = torchrng.PRNGKey(seed=100)
# Example of Step 2. new_fun is deterministic, so passing the same key gives the same results.
new_fun = torchrng.deterministic(random_fun)
x = new_fun(key, tensor)
y = new_fun(key, tensor)
print("Expected: True, Actual:", (x == y).all())
# Example of Step 3. Passing a different PRNGKey gives different results.
new_key, key = torchrng.split(key)
y = new_fun(new_key, tensor)
print("Expected: False, Actual:", (x == y).all())
- If your machine has many devices, TorchRNG becomes very slow because it calls torch.random.fork_rng many times. See the PyTorch documentation for details.
- Currently, the PRNGKey split implementation is very simple and may not preserve randomness as well as PyTorch's own generators do.