bashtage/ng-numpy-randomstate

A bug in MultithreadedRNG

Closed this issue · 8 comments

tkf commented

I think there is a bug in the following initialization step of MultithreadedRNG shown in the documentation:

self._random_states = [rs]
for _ in range(1, threads):
    _rs = randomstate.prng.xorshift1024.RandomState()
    rs.jump()
    _rs.set_state(rs.get_state())
    self._random_states.append(_rs)

Namely, the first two random states start from the same state. You can check it by:

>>> mrng = MultithreadedRNG(0, seed=1, threads=2)
>>> rs0, rs1 = mrng._random_states
>>> (rs0.get_state()['state'][0] == rs1.get_state()['state'][0]).all()
True

This can be fixed, for example, by calling _rs.jump() instead of rs.jump() like this:

self._random_states = [rs]
for _ in range(1, threads):
    _rs = randomstate.prng.xorshift1024.RandomState()
    _rs.set_state(rs.get_state())
    _rs.jump()  # <--- fixed
    self._random_states.append(_rs)

Here is the whole code in a Jupyter notebook:
https://gist.github.com/tkf/20d298879ff9d4d52212b3350e2b7262

Thanks for the great package, BTW!

tkf commented

Since this pattern is common in parallel random number generation, it probably makes sense to have something like

def independent_random_states(RandomState, seed, num):
    random_states = [RandomState(seed)]
    for _ in range(1, num):
        rs = RandomState()
        rs.set_state(random_states[-1].get_state())
        rs.jump()
        random_states.append(rs)
    return random_states

in the randomstate library itself?

Here is a quick test for it:

def samestate(rs0, rs1):
    return rs0.get_state() == rs1.get_state()


def test_independent_random_states():
    num = 3
    random_states = independent_random_states(
        randomstate.prng.xoroshiro128plus.RandomState,
        1, num)
    assert len(random_states) == num

    for i in range(num):
        assert samestate(random_states[i], random_states[i])
        for j in range(i + 1, num):
            assert not samestate(random_states[i], random_states[j])

Thanks for reporting this. An unfortunate typo.

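I think a better fix is to start the list with a copy of rs instead of rs itself:

_rs0 = randomstate.prng.xorshift1024.RandomState()
_rs0.set_state(rs.get_state())  # set_state() returns None, so copy in two steps
self._random_states = [_rs0]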

The problem was that the first element of the list was rs itself, so every call to rs.jump() also changed the state of the first element, which always ended up in the same state as the last one.
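A minimal sketch of that aliasing, assuming a hypothetical pure-Python stand-in ToyRNG (not part of randomstate) whose state is just a position in one long stream and whose jump() skips a fixed, illustrative number of draws. buggy_init is a hypothetical name for the original loop from the docs:

```python
class ToyRNG:
    """Hypothetical stand-in for a jumpable PRNG such as xorshift1024.

    The state is a position in one long stream; jump() skips ahead by a
    fixed number of draws. JUMP is an illustrative value, not the real one.
    """

    JUMP = 2**32

    def __init__(self, seed=0):
        self._pos = seed

    def get_state(self):
        return {"pos": self._pos}

    def set_state(self, state):
        self._pos = state["pos"]

    def jump(self):
        self._pos += self.JUMP


def buggy_init(seed, threads):
    """Mirrors the original initialization loop from the docs."""
    rs = ToyRNG(seed)
    random_states = [rs]  # rs itself is stored, not a copy
    for _ in range(1, threads):
        _rs = ToyRNG()
        rs.jump()  # also advances random_states[0], which *is* rs
        _rs.set_state(rs.get_state())
        random_states.append(_rs)
    return random_states


rngs = buggy_init(seed=1, threads=4)
# Number of jumps each generator is ahead of the seed
print([r.get_state()["pos"] // ToyRNG.JUMP for r in rngs])  # [3, 1, 2, 3]
```

With threads=2 this reduces to the case in the report: both generators end up exactly one jump ahead of the seed.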

The best fix would be something like (not tested):

        self._random_states = []
        for _ in range(0, threads - 1):
            _rs = randomstate.prng.xorshift1024.RandomState()
            _rs.set_state(rs.get_state())
            self._random_states.append(_rs)
            rs.jump()
        self._random_states.append(rs)

so that the original RandomState ends up as the final one in the list.
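A sketch of that loop, again assuming a hypothetical ToyRNG stand-in (not the library's xorshift1024) with the same get_state/set_state/jump methods; best_fix_init is an illustrative name. Each copy is taken before the original is advanced, and the original is appended last:

```python
class ToyRNG:
    """Hypothetical stand-in for a jumpable PRNG (state = stream position)."""

    JUMP = 2**32  # illustrative jump size

    def __init__(self, seed=0):
        self._pos = seed

    def get_state(self):
        return {"pos": self._pos}

    def set_state(self, state):
        self._pos = state["pos"]

    def jump(self):
        self._pos += self.JUMP


def best_fix_init(seed, threads):
    """Mirrors the suggested fix: copy first, then jump the original."""
    rs = ToyRNG(seed)
    random_states = []
    for _ in range(threads - 1):
        _rs = ToyRNG()
        _rs.set_state(rs.get_state())  # snapshot before advancing
        random_states.append(_rs)
        rs.jump()  # only the original is advanced
    random_states.append(rs)  # original, jumped threads - 1 times, goes last
    return random_states, rs


rngs, original = best_fix_init(seed=0, threads=4)
print([r.get_state()["pos"] // ToyRNG.JUMP for r in rngs])  # [0, 1, 2, 3]
print(rngs[-1] is original)  # True
```

Because the copies are separate objects, later jumps on the original cannot change them.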

tkf commented

Thanks for updating the docs!

Why is it better to have "that the original RandomState was the final one in the list"?


It is the cleanest since the original RS is "generating" all of the states for use in the other RSs, and so its state will always change. One could instead use it only to generate the states and then throw it away, in which case one could create one new RS per thread. This is slightly wasteful but probably a little cleaner.
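A minimal sketch of the throwaway variant, assuming a hypothetical ToyRNG stand-in (not part of randomstate); throwaway_init is an illustrative name. A source generator produces every state and is discarded afterwards:

```python
class ToyRNG:
    """Hypothetical stand-in for a jumpable PRNG (state = stream position)."""

    JUMP = 2**32  # illustrative jump size

    def __init__(self, seed=0):
        self._pos = seed

    def get_state(self):
        return {"pos": self._pos}

    def set_state(self, state):
        self._pos = state["pos"]

    def jump(self):
        self._pos += self.JUMP


def throwaway_init(seed, threads):
    """Hypothetical variant: one source RNG produces every state, then is dropped."""
    source = ToyRNG(seed)
    random_states = []
    for _ in range(threads):
        _rs = ToyRNG()
        _rs.set_state(source.get_state())
        random_states.append(_rs)
        source.jump()  # the final jump is never used: the "slightly wasteful" part
    return random_states


rngs = throwaway_init(seed=0, threads=4)
print([r.get_state()["pos"] // ToyRNG.JUMP for r in rngs])  # [0, 1, 2, 3]
```

Every element of the list is a fresh object, so none of them is aliased to the generator doing the jumping.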

tkf commented

I suppose it is nicer from an aesthetic point of view, since a single RS mutates its state via .jump, whereas in my version .jump is called on every new RS? I mean, I just wondered whether my version and your version of MultithreadedRNG generate different random numbers (presumably they do, right?). Or does some initialization have to happen when .jump is called right after .set_state, making your version more computationally efficient?

I don't think your original fix works for more than 2 threads since you are iterating over:

  1. Setting the new RS from the old (original) one
  2. Advancing the new one by one jump

If threads > 2, then I think several RSs would end up with the same state, since the original is never advanced (I may be wrong here).
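A quick check of that suspicion, using a hypothetical ToyRNG stand-in (not the library's xorshift1024); tkf_fix_init is an illustrative name for the fix proposed in the report. Since rs is never advanced, every copy lands exactly one jump ahead of the seed:

```python
class ToyRNG:
    """Hypothetical stand-in for a jumpable PRNG (state = stream position)."""

    JUMP = 2**32  # illustrative jump size

    def __init__(self, seed=0):
        self._pos = seed

    def get_state(self):
        return {"pos": self._pos}

    def set_state(self, state):
        self._pos = state["pos"]

    def jump(self):
        self._pos += self.JUMP


def tkf_fix_init(seed, threads):
    """Mirrors the fix proposed in the report: jump the copy, not the original."""
    rs = ToyRNG(seed)
    random_states = [rs]
    for _ in range(1, threads):
        _rs = ToyRNG()
        _rs.set_state(rs.get_state())  # rs is never advanced...
        _rs.jump()  # ...so every copy is exactly one jump ahead
        random_states.append(_rs)
    return random_states


print([r.get_state()["pos"] // ToyRNG.JUMP for r in tkf_fix_init(0, 2)])  # [0, 1]
print([r.get_state()["pos"] // ToyRNG.JUMP for r in tkf_fix_init(0, 3)])  # [0, 1, 1]
```

With two threads the proposed fix looks correct, which is why the duplicate only shows up from three threads on.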

Pretty sure they would generate different series.

I think the example in the docs is probably the best order, since you will have a sequence of RSs of the form

  1. RS state set using seed
  2. RS from 1, jumped 1 unit
  3. RS from 2, jumped 1 unit
  ...
  N. RS from N-1, jumped 1 unit

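
Assuming jump() depends only on the current state, tkf's independent_random_states helper and the loop with the original RS last both build exactly this sequence. A sketch comparing them with a hypothetical ToyRNG stand-in (not part of randomstate); best_fix_states is an illustrative name:

```python
class ToyRNG:
    """Hypothetical stand-in for a jumpable PRNG (state = stream position)."""

    JUMP = 2**32  # illustrative jump size

    def __init__(self, seed=0):
        self._pos = seed

    def get_state(self):
        return {"pos": self._pos}

    def set_state(self, state):
        self._pos = state["pos"]

    def jump(self):
        self._pos += self.JUMP


def independent_random_states(RandomState, seed, num):
    """tkf's helper from this thread: each RS is the previous one, jumped once."""
    random_states = [RandomState(seed)]
    for _ in range(1, num):
        rs = RandomState()
        rs.set_state(random_states[-1].get_state())
        rs.jump()
        random_states.append(rs)
    return random_states


def best_fix_states(RandomState, seed, threads):
    """The loop that keeps the original RS last, parameterized on the RNG class."""
    rs = RandomState(seed)
    random_states = []
    for _ in range(threads - 1):
        _rs = RandomState()
        _rs.set_state(rs.get_state())
        random_states.append(_rs)
        rs.jump()
    random_states.append(rs)
    return random_states


a = [r.get_state() for r in independent_random_states(ToyRNG, 1, 5)]
b = [r.get_state() for r in best_fix_states(ToyRNG, 1, 5)]
print(a == b)  # True
print([(s["pos"] - 1) // ToyRNG.JUMP for s in a])  # [0, 1, 2, 3, 4]
```

Item k in both lists is the seeded state jumped k - 1 times, matching the list above.
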
tkf commented

You are right. I'm actually using independent_random_states, so I wasn't paying much attention to my version of MultithreadedRNG. My bad.