facebookresearch/beanmachine

Simulate calls queried functions

BerndSchuscha opened this issue · 1 comment

Issue Description

When the function simulate is used, the graphical network is evaluated N times per sample, where N is the number of @bm.random_variable functions in the network.

Steps to Reproduce

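```python
import beanmachine.ppl as bm
import torch.distributions as dist


@bm.random_variable
def A():
    return dist.Normal(1.0, 1.0)


@bm.random_variable
def B():
    return dist.Normal(1.0, 1.0)


@bm.random_variable
def C():
    print("C")  # printed every time the network evaluates C
    return dist.Normal(A() + B(), 1.0)


obs_queries = [C()]
predictives = bm.simulate(obs_queries, num_samples=1)
# -> "C" is printed 3 times, i.e. 3 calls for a single sample
```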

(The same call count can also be observed with a profiler.)

Expected Behavior

One call per sample.

System Info


  • PyTorch Version 1.12.1
  • Python version 3.9

Additional Context

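This happens because simulate uses "inference = SingleSiteAncestralMetropolisHastings()" for each sample step. A single-site step proposes a new value for each random variable in turn and re-evaluates the network after every proposal, which is exactly where the N evaluations per sample come from.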
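To make the call count concrete, here is a toy sketch in plain Python of the A, B, C network from the reproduction. It is not Bean Machine code: `ToyNetwork`, `evaluate`, and `single_site_step` are hypothetical names, and the accept/reject test of Metropolis-Hastings is deliberately left out because only the number of network evaluations matters here.

```python
import random


class ToyNetwork:
    """Hypothetical stand-in for the A, B, C network in the issue (not Bean Machine code)."""

    NODES = ("A", "B", "C")

    def __init__(self):
        self.c_calls = 0  # counts runs of C's body, like print('C') in the reproduction

    def evaluate(self, values):
        """Run every node body once, keeping the given values and sampling the missing ones."""
        a = values["A"] if "A" in values else random.gauss(1.0, 1.0)
        b = values["B"] if "B" in values else random.gauss(1.0, 1.0)
        self.c_calls += 1  # C's body runs on every evaluation
        c = values["C"] if "C" in values else random.gauss(a + b, 1.0)
        return {"A": a, "B": b, "C": c}


def single_site_step(net, world):
    """One sample step in the single-site style: each node gets its own proposal,
    and every proposal re-evaluates the whole network (accept/reject omitted)."""
    for name in net.NODES:
        others = {k: v for k, v in world.items() if k != name}
        world = net.evaluate(others)  # resample `name`, keep the other nodes fixed
    return world


net = ToyNetwork()
world = net.evaluate({})  # initial world
net.c_calls = 0           # count only the sample step itself
world = single_site_step(net, world)
print(net.c_calls)  # 3, one evaluation per random variable
```

With N random variables the loop runs N times, so one sample costs N evaluations, matching the 3 calls observed for A, B and C.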

A solution could be to replace the "next" call in predictive.py with a function that simply uses a random proposer to generate a new world and returns that world (an adapted send method from "sampler.py").
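The idea behind the proposal can be sketched with the same kind of toy network, assuming there are no observed values (i.e. sampling from the prior predictive): each sample is then simply a fresh world drawn in topological order, with no Metropolis-Hastings step. `ToyNetwork` and `forward_sample` are hypothetical names for illustration only and do not correspond to Bean Machine's internals.

```python
import random


class ToyNetwork:
    """Hypothetical stand-in for the A, B, C network in the issue (not Bean Machine code)."""

    def __init__(self):
        self.c_calls = 0  # counts runs of C's body, like print('C') in the reproduction

    def evaluate(self):
        """Draw every node from its prior in topological order (A and B before C)."""
        a = random.gauss(1.0, 1.0)
        b = random.gauss(1.0, 1.0)
        self.c_calls += 1
        c = random.gauss(a + b, 1.0)
        return {"A": a, "B": b, "C": c}


def forward_sample(net, num_samples):
    """Hypothetical replacement for the per-sample step: one ancestral
    (forward) pass per sample, so the network is evaluated once each time."""
    return [net.evaluate() for _ in range(num_samples)]


net = ToyNetwork()
samples = forward_sample(net, num_samples=4)
print(net.c_calls)  # 4, one evaluation per sample
```

The cost per sample no longer depends on the number of random variables. Whether this holds once posterior samples are passed to simulate depends on how the real implementation fixes those values, which this sketch does not model.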

Is this solution viable, or would it break something else further down the line?

That should be workable.