facebookresearch/beanmachine

Using var_names argument in Arviz plots raises 'RVIdentifier' object has no attribute 'startswith'

feynmanliang opened this issue · 10 comments

Issue Description

var_names is the arviz mechanism for plotting only a subset of the posterior samples.
It is important for avoiding an excessive number of plots when the posterior is high-dimensional.
However, passing this argument together with an InferenceData obtained from beanmachine
raises an exception.

Steps to Reproduce

import beanmachine.ppl as bm
import torch
import torch.distributions as dist
import arviz as az
foo = bm.random_variable(lambda: dist.MultivariateNormal(torch.zeros(4), torch.eye(4)))
xx = bm.SingleSiteRandomWalk().infer(queries=[foo()], observations={}, num_samples=100, num_chains=2)
az.plot_trace(xx.to_inference_data(), var_names=['xx()'])

will raise

'RVIdentifier' object has no attribute 'startswith'
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-37-cb5a8769793b> in <module>
      3 idata = samples.to_inference_data()
      4 
----> 5 az.plot_trace(
      6     idata, var_names=['prevalence()']
      7 )
/mnt/xarfuse/uid-25957/a9b3979e-seed-nspid4026533620_cgpid78576781-ns-4026533617/arviz/plots/traceplot.py in plot_trace(data, var_names, filter_vars, transform, coords, divergences, kind, figsize, rug, lines, circ_var_names, circ_var_units, compact, compact_prop, combined, chain_prop, legend, plot_kwargs, fill_kwargs, rug_kwargs, hist_kwargs, trace_kwargs, rank_kwargs, axes, backend, backend_config, backend_kwargs, show)
    182         coords_data = transform(coords_data)
    183 
--> 184     var_names = _var_names(var_names, coords_data, filter_vars)
    185 
    186     if compact:
/mnt/xarfuse/uid-25957/a9b3979e-seed-nspid4026533620_cgpid78576781-ns-4026533617/arviz/utils.py in _var_names(var_names, data, filter_vars)
     46             all_vars = list(data.data_vars)
     47 
---> 48         all_vars_tilde = [var for var in all_vars if var.startswith("~")]
     49         if all_vars_tilde:
     50             warnings.warn(
/mnt/xarfuse/uid-25957/a9b3979e-seed-nspid4026533620_cgpid78576781-ns-4026533617/arviz/utils.py in <listcomp>(.0)
     46             all_vars = list(data.data_vars)
     47 
---> 48         all_vars_tilde = [var for var in all_vars if var.startswith("~")]
     49         if all_vars_tilde:
     50             warnings.warn(
AttributeError: 'RVIdentifier' object has no attribute 'startswith'

Expected Behavior

A trace plot containing only the xx() RV

Additional Context

Stringifying the RVIdentifier keys before building the InferenceData sidesteps this issue, e.g.:

idata = az.convert_to_inference_data({
    str(k): samples[k] for k in samples
})
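
To illustrate why the workaround helps: the traceback shows arviz calling var.startswith("~") on every variable name, which assumes the names are strings. Below is a minimal, dependency-free sketch of that failure mode. FakeRVIdentifier and find_tilde_names are hypothetical stand-ins written for illustration; they are not the beanmachine or arviz implementations.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeRVIdentifier:
    """Hypothetical stand-in for beanmachine's RVIdentifier (not the real class)."""

    name: str

    def __str__(self) -> str:
        return f"{self.name}()"


def find_tilde_names(all_vars):
    """Mimics the failing line in the traceback: assumes every name is a str."""
    return [var for var in all_vars if var.startswith("~")]


# A samples-like mapping keyed by identifier objects rather than strings.
samples = {FakeRVIdentifier("xx"): [0.1, 0.2, 0.3]}

# Non-string keys fail the same way as in the traceback.
try:
    find_tilde_names(list(samples))
    raised = False
except AttributeError:
    raised = True

# Stringified keys pass through the same check without error.
string_keyed = {str(k): v for k, v in samples.items()}
tilde_names = find_tilde_names(list(string_keyed))
```

With string keys, a name such as 'xx()' can then be matched by var_names, assuming the real identifier stringifies in that form.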

@feynmanliang for this issue and #1563 could you say what version of arviz was used? I believe this error was fixed in a recent version.

Arviz 0.11.2 for both

I don't get this error in the latest versions of arviz, but I'm getting another error:

KeyError: 'var names: "[\'xx()\'] are not present" in dataset'

I'm not sure the semantics of var_names support this. The best way to achieve it is to rename the variable with xarray's rename:

import beanmachine.ppl as bm
import torch
import torch.distributions as dist
import arviz as az
foo = bm.random_variable(lambda: dist.MultivariateNormal(torch.zeros(4), torch.eye(4)))
xx = bm.SingleSiteRandomWalk().infer(queries=[foo()], observations={}, num_samples=100, num_chains=2)
xx = xx.to_inference_data().rename({foo(): "xx"})
az.plot_trace(xx, var_names="xx")

@feynmanliang does this adequately address your concerns?

Apologies, I didn't mean for the lambda name inference / rename() to be a part of this issue.

With arviz 0.12.1, I just tried

import beanmachine.ppl as bm
import torch
import torch.distributions as dist
import arviz as az
@bm.random_variable
def xx():
  return dist.MultivariateNormal(torch.zeros(4), torch.eye(4))
mcs = bm.SingleSiteRandomWalk().infer(queries=[xx()], observations={}, num_samples=100, num_chains=2)
az.plot_trace(mcs.to_inference_data(), var_names=[xx()])

and it succeeds, but fails on 0.11.2. So yes, thank you for helping me figure out this is an arviz version issue!
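
For code that has to run against both arviz versions mentioned in this thread, one option is to apply the stringification workaround only on older releases. A rough sketch follows. parse_version and needs_string_keys are hypothetical helpers, and the (0, 12, 1) threshold is only the version reported working above, not necessarily the release in which the fix landed.

```python
def parse_version(version: str) -> tuple:
    """Turn a plain 'X.Y.Z' version string into a tuple of ints.

    Assumes purely numeric components; pre-release suffixes are not handled.
    """
    return tuple(int(part) for part in version.split("."))


# Assumption: 0.12.1 is the version reported working in this thread.
KNOWN_GOOD = (0, 12, 1)


def needs_string_keys(arviz_version: str) -> bool:
    """Return True when the str(k) workaround should be applied."""
    return parse_version(arviz_version) < KNOWN_GOOD


old = needs_string_keys("0.11.2")  # version reported failing
new = needs_string_keys("0.12.1")  # version reported working
```

In practice the argument would be az.__version__. Releases between the two reported versions are treated conservatively here, which is harmless since string keys work either way.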