A bug in MultithreadedRNG
I think there is a bug in the following initialization step of MultithreadedRNG, as shown in the documentation:
self._random_states = [rs]
for _ in range(1, threads):
    _rs = randomstate.prng.xorshift1024.RandomState()
    rs.jump()
    _rs.set_state(rs.get_state())
    self._random_states.append(_rs)

Namely, the first two random states start from the same state. You can check it by:
>>> mrng = MultithreadedRNG(0, seed=1, threads=2)
>>> rs0, rs1 = mrng._random_states
>>> (rs0.get_state()['state'][0] == rs1.get_state()['state'][0]).all()
True

This can be fixed, for example, by calling _rs.jump() instead of rs.jump() like this:
self._random_states = [rs]
for _ in range(1, threads):
    _rs = randomstate.prng.xorshift1024.RandomState()
    _rs.set_state(rs.get_state())
    _rs.jump()  # <--- fixed
    self._random_states.append(_rs)

Here is the whole code in a Jupyter notebook:
https://gist.github.com/tkf/20d298879ff9d4d52212b3350e2b7262
Thanks for the great package, BTW!
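The mechanism can be reproduced without randomstate. This is a minimal sketch under stated assumptions: ToyRandomState is a hypothetical stand-in (not the library's class) whose whole state is a position in one long stream, with jump() skipping ahead a fixed, arbitrarily chosen 2**64 draws; docs_init and jumps are hypothetical helper names. docs_init mirrors the documented loop quoted above.

```python
class ToyRandomState:
    """Hypothetical stand-in for a jumpable PRNG; NOT randomstate's class.

    Its whole state is a position in one long stream, and jump()
    skips ahead JUMP draws in a single step.
    """
    JUMP = 2**64  # arbitrary jump distance chosen for the illustration

    def __init__(self, seed=0):
        self._pos = seed

    def get_state(self):
        return {'pos': self._pos}

    def set_state(self, state):
        self._pos = state['pos']

    def jump(self):
        self._pos += self.JUMP


def jumps(rs, seed):
    """How many jumps separate rs from the seed position."""
    return (rs.get_state()['pos'] - seed) // ToyRandomState.JUMP


def docs_init(seed, threads):
    """Mirror of the documented loop, including its bug."""
    rs = ToyRandomState(seed)
    random_states = [rs]               # element 0 IS rs, not a copy
    for _ in range(1, threads):
        _rs = ToyRandomState()
        rs.jump()                      # moves random_states[0] too
        _rs.set_state(rs.get_state())  # copy gets the same state as rs
        random_states.append(_rs)
    return random_states


for threads in (2, 3, 4):
    print(threads, [jumps(r, 1) for r in docs_init(1, threads)])
# 2 [1, 1]
# 3 [2, 1, 2]
# 4 [3, 1, 2, 3]
```

Because rs is both element 0 and the object being jumped, element 0 always ends up in the same state as the last copy; with threads=2 that is exactly the pair caught by the >>> check above.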
Since this pattern is frequent for parallel random number generation, probably it makes sense to have something like
def independent_random_states(RandomState, seed, num):
    random_states = [RandomState(seed)]
    for _ in range(1, num):
        rs = RandomState()
        rs.set_state(random_states[-1].get_state())
        rs.jump()
        random_states.append(rs)
    return random_states

in the randomstate library itself?
Here is a quick test for it:
import randomstate.prng.xoroshiro128plus

def samestate(rs0, rs1):
    return rs0.get_state() == rs1.get_state()

def test_independent_random_states():
    num = 3
    random_states = independent_random_states(
        randomstate.prng.xoroshiro128plus.RandomState,
        1, num)
    assert len(random_states) == num
    for i in range(num):
        assert samestate(random_states[i], random_states[i])
        for j in range(i + 1, num):
            assert not samestate(random_states[i], random_states[j])

Thanks for reporting this. An unfortunate typo.
I think a better fix is to start the list with a copy of rs rather than rs itself:
self._random_states = [randomstate.prng.xorshift1024.RandomState()]
self._random_states[0].set_state(rs.get_state())  # set_state() updates in place and returns None
The problem was that every call to jump was changing the state of the first element in the list.
The best fix would be something like (not tested):
self._random_states = []
for _ in range(0, threads - 1):
    _rs = randomstate.prng.xorshift1024.RandomState()
    _rs.set_state(rs.get_state())
    self._random_states.append(_rs)
    rs.jump()
self._random_states.append(rs)
so that the original RandomState was the final one in the list.
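A minimal sketch of this ordering, assuming a hypothetical ToyRandomState stand-in (not randomstate's class; its state is a position in one stream and jump() skips a fixed 2**64 draws). final_original_init is a hypothetical name for the loop above.

```python
class ToyRandomState:
    """Hypothetical stand-in for a jumpable PRNG; NOT randomstate's class."""
    JUMP = 2**64  # jump() skips this many draws (arbitrary choice)

    def __init__(self, seed=0):
        self._pos = seed  # position in one long stream

    def get_state(self):
        return {'pos': self._pos}

    def set_state(self, state):
        self._pos = state['pos']

    def jump(self):
        self._pos += self.JUMP


def final_original_init(seed, threads):
    """Copy rs before each jump; rs itself is appended last."""
    rs = ToyRandomState(seed)
    random_states = []
    for _ in range(threads - 1):
        _rs = ToyRandomState()
        _rs.set_state(rs.get_state())  # snapshot taken before rs moves
        random_states.append(_rs)
        rs.jump()                      # rs now sits one jump further on
    random_states.append(rs)           # original object goes last
    return random_states


states = final_original_init(1, 4)
print([(r.get_state()['pos'] - 1) // ToyRandomState.JUMP for r in states])
# [0, 1, 2, 3]
```

Each snapshot is taken before rs moves on, and rs only joins the list after its last jump, so no later jump can disturb an entry that has already been handed out.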
Thanks for the update of the doc!
Why is it better to have "that the original RandomState was the final one in the list"?
It is the cleanest since the original RS is "generating" all of the states for use in the other RSs, and so its state will always change. One could instead use it only to generate the states and then throw it away, in which case one would create a new RS for each of the threads. This is slightly wasteful but probably a little cleaner.
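The throwaway variant could look like the sketch below. It again uses a hypothetical ToyRandomState stand-in rather than randomstate (state = position in one stream, jump() skips a fixed 2**64 draws), and discard_source_init is a hypothetical name; the point is only the order of operations.

```python
class ToyRandomState:
    """Hypothetical stand-in for a jumpable PRNG; NOT randomstate's class."""
    JUMP = 2**64  # jump() skips this many draws (arbitrary choice)

    def __init__(self, seed=0):
        self._pos = seed  # position in one long stream

    def get_state(self):
        return {'pos': self._pos}

    def set_state(self, state):
        self._pos = state['pos']

    def jump(self):
        self._pos += self.JUMP


def discard_source_init(seed, threads):
    """A throwaway source RS only hands out states; it is never returned."""
    source = ToyRandomState(seed)
    random_states = []
    for _ in range(threads):
        rs = ToyRandomState()
        rs.set_state(source.get_state())  # new RS starts where source is now
        random_states.append(rs)
        source.jump()                     # the final jump is never used
    return random_states                  # source is dropped here


states = discard_source_init(1, 3)
print([(r.get_state()['pos'] - 1) // ToyRandomState.JUMP for r in states])
# [0, 1, 2]
```

The cost is one extra object and one unused jump, which is the slight waste mentioned above.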
I suppose it is nicer from an aesthetic point of view, since a single RS mutates its state via .jump, whereas in my version .jump is called on every new RS? I mean, I just wondered whether my version and your version of MultithreadedRNG generate different random numbers (presumably they do, right?). Or does some initialization have to happen when .jump is called right after .set_state, so that your version is more computationally efficient?
I don't think your original fix works for more than 2 threads, since on each iteration you are:
- Setting the new RS from the old one
- Advancing the new one with jump
The old RS is never advanced, so if threads > 2 then I think multiple RSs would have the same state (I may be wrong here).
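This can be checked with a hypothetical ToyRandomState stand-in (not randomstate's class; state = position in one stream, jump() skips a fixed 2**64 draws). first_fix_init is a hypothetical name mirroring the first proposed fix: copy rs, then jump the copy.

```python
class ToyRandomState:
    """Hypothetical stand-in for a jumpable PRNG; NOT randomstate's class."""
    JUMP = 2**64  # jump() skips this many draws (arbitrary choice)

    def __init__(self, seed=0):
        self._pos = seed  # position in one long stream

    def get_state(self):
        return {'pos': self._pos}

    def set_state(self, state):
        self._pos = state['pos']

    def jump(self):
        self._pos += self.JUMP


def first_fix_init(seed, threads):
    """Mirror of the first proposed fix: copy rs, then jump the copy."""
    rs = ToyRandomState(seed)
    random_states = [rs]
    for _ in range(1, threads):
        _rs = ToyRandomState()
        _rs.set_state(rs.get_state())  # rs never moves, so every copy starts at the seed
        _rs.jump()                     # ...and every copy is jumped exactly once
        random_states.append(_rs)
    return random_states


for threads in (2, 3):
    print(threads, [(r.get_state()['pos'] - 1) // ToyRandomState.JUMP
                    for r in first_fix_init(1, threads)])
# 2 [0, 1]
# 3 [0, 1, 1]
```

With two threads the fix looks correct, which is why a check on only the first two states cannot catch it; from three threads on, every copy after element 0 lands on the same state.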
Pretty sure they would generate different series.
I think the example in the docs is probably the best order, since you will have a sequence of RSs of the form
1. RS state set using seed
2. RS from 1, jumped 1 unit
3. RS from 2, jumped 1 unit
...
N. RS from N-1, jumped 1 unit.
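The chained form listed above, sketched with a hypothetical ToyRandomState stand-in (not randomstate's class; state = position in one stream, jump() skips a fixed 2**64 draws). chained_init is a hypothetical name; it has the same shape as independent_random_states earlier in the thread.

```python
class ToyRandomState:
    """Hypothetical stand-in for a jumpable PRNG; NOT randomstate's class."""
    JUMP = 2**64  # jump() skips this many draws (arbitrary choice)

    def __init__(self, seed=0):
        self._pos = seed  # position in one long stream

    def get_state(self):
        return {'pos': self._pos}

    def set_state(self, state):
        self._pos = state['pos']

    def jump(self):
        self._pos += self.JUMP


def chained_init(seed, threads):
    """Entry k is a copy of entry k-1, jumped once."""
    random_states = [ToyRandomState(seed)]
    for _ in range(1, threads):
        rs = ToyRandomState()
        rs.set_state(random_states[-1].get_state())  # start from the previous entry
        rs.jump()                                    # then move one jump on
        random_states.append(rs)
    return random_states


states = chained_init(1, 5)
print([(r.get_state()['pos'] - 1) // ToyRandomState.JUMP for r in states])
# [0, 1, 2, 3, 4]
```

Entry k sits exactly k jumps from the seed, so all entries are distinct for any number of threads.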
You are right. I'm actually using independent_random_states so I wasn't paying much attention to my version of MultithreadedRNG. My bad.