rl_cmdstanr

Mini tutorial on {cmdstanr} for the simple Rescorla-Wagner RL model

v1 - Lei Zhang - 10 Aug 2020

Note: this mini tutorial assumes you already know a bit about {RStan} and are now considering switching to {CmdStanR}.

tl;dr

f = rw_run('rw_hba', saveFit = TRUE, test = FALSE)  # fit rw_hba.stan and save the fitted object

longer explanation

This repository contains:

root
  ├─ data       # example data for a simple 2-armed bandit task
  ├─ scripts    # stan model and R function
  └─ stanfits   # to save fitted objects

{CmdStanR} is a lightweight interface to Stan for R users that provides an alternative to the traditional {RStan} interface. It gives access to the latest CmdStan features and provides convenient tools for working with posterior draws. In addition, at the time of writing, {CmdStanR} seems to be the only R interface that supports multithreading. See here for a comparison between {CmdStanR} and {RStan}.

If you are a researcher using computational modeling to understand cognition, have used {RStan}, and are now considering switching to {CmdStanR}, this mini tutorial is for you. In essence, all your *.stan models stay intact; all you need to change is the R wrapper function that calls them.

As an example, here I have a simple Rescorla-Wagner model for a 2-armed bandit task, fitted with the hierarchical Bayesian approach (HBA; Ahn et al., 2017). The model is called rw_hba.stan. The wrapper function is called rw_run.R, and its main input argument is the model string modelStr = 'rw_hba'. All commands in this wrapper function have been updated to work with the {CmdStanR} package.
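For readers new to the model: the Rescorla-Wagner rule updates the value of the chosen option by a fraction (the learning rate) of the prediction error, and choices are commonly modeled with a softmax rule. The base-R sketch below simulates a single agent on a 2-armed bandit. The function name `simulate_rw`, the parameter names `alpha` (learning rate) and `tau` (inverse temperature), the 0/1 reward coding and the reward probabilities are all assumptions for illustration, not details read from rw_hba.stan.

```r
# Minimal Rescorla-Wagner + softmax simulation for a 2-armed bandit.
# Hypothetical helper: parameter names, reward coding (0/1) and task settings
# are assumptions, not taken from rw_hba.stan.
simulate_rw <- function(nTrials = 100, alpha = 0.3, tau = 3,
                        pReward = c(0.7, 0.3), seed = 1) {
  set.seed(seed)
  v <- c(0, 0)                          # initial values of the two options
  choice <- integer(nTrials)
  reward <- numeric(nTrials)
  for (t in seq_len(nTrials)) {
    # softmax choice rule: P(option k) proportional to exp(tau * v[k])
    p <- exp(tau * v) / sum(exp(tau * v))
    choice[t] <- sample(1:2, size = 1, prob = p)
    # binary reward drawn with the chosen option's reward probability
    reward[t] <- as.numeric(runif(1) < pReward[choice[t]])
    # delta rule: update only the chosen option by alpha * prediction error
    pe <- reward[t] - v[choice[t]]
    v[choice[t]] <- v[choice[t]] + alpha * pe
  }
  list(choice = choice, reward = reward, finalValues = v)
}

sim <- simulate_rw()
table(sim$choice)   # option 1 (higher reward probability) is typically chosen more often
```

Because rewards are 0 or 1, values start at 0 and `alpha` lies between 0 and 1, the learned values always stay within [0, 1].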

To run the model, simply call:

f = rw_run('rw_hba', saveFit = TRUE, test = FALSE)  # fit rw_hba.stan and save the fitted object

In case you are wondering, the core step is to create a CmdStanModel object:

library(cmdstanr)
mod = cmdstan_model(modelFile)  # compile the .stan file into an executable model
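`cmdstan_model()` takes the path to a .stan file. How rw_run.R builds `modelFile` from `modelStr` is not spelled out here, so the following is only a sketch, assuming the .stan files sit in the scripts folder of the repository tree; the helper name `model_path` is hypothetical.

```r
# Hypothetical helper: build the path to a Stan model from its name.
# Assumes the .stan files live in the repository's scripts/ folder.
model_path <- function(modelStr, modelDir = "scripts") {
  modelFile <- file.path(modelDir, paste0(modelStr, ".stan"))
  if (!file.exists(modelFile)) {
    warning("Stan file not found: ", modelFile)
  }
  modelFile
}

modelFile <- model_path("rw_hba")   # "scripts/rw_hba.stan"
# mod <- cmdstanr::cmdstan_model(modelFile)  # requires cmdstanr and CmdStan
```

Checking that the file exists before compiling gives a clearer message than a compiler error when `modelStr` is misspelled.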

Then, to run MCMC sampling, call the $sample() method of the mod object:

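fit = mod$sample(
  data = dataList,                  # named list matching the data block of the .stan file
  chains = nChains,                 # number of MCMC chains
  parallel_chains = nCores,         # maximum number of chains run in parallel
  refresh = nRefresh,               # print progress every nRefresh iterations
  iter_warmup = nWarmup,            # warmup (adaptation) iterations per chain
  iter_sampling = nIters - nWarmup, # post-warmup draws per chain
  seed = est_seed,                  # random seed for reproducibility
  max_treedepth = treedepth,        # maximum NUTS tree depth
  adapt_delta = adapt               # target acceptance rate during adaptation
)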
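Most of the switch from {RStan} lies in argument names. In `rstan::sampling()`, `iter` counts warmup plus sampling iterations, whereas `$sample()` takes `iter_warmup` and `iter_sampling` separately, which is why the call above passes `nIters - nWarmup`. Likewise, `cores` becomes `parallel_chains`, and `adapt_delta` and `max_treedepth` move out of RStan's `control` list into top-level arguments. The helper `rstan_to_cmdstanr_args` below is a hypothetical sketch of that mapping, not part of either package.

```r
# Hypothetical helper: translate RStan-style sampler settings into the
# argument names used by CmdStanR's $sample() method.
rstan_to_cmdstanr_args <- function(iter, warmup, chains, cores, seed,
                                   refresh, control = list()) {
  out <- list(
    chains          = chains,
    parallel_chains = cores,           # rstan: cores
    iter_warmup     = warmup,          # rstan: warmup
    iter_sampling   = iter - warmup,   # rstan's iter includes warmup
    seed            = seed,
    refresh         = refresh
  )
  # rstan nests these in `control`; cmdstanr takes them as top-level arguments
  if (!is.null(control$adapt_delta))   out$adapt_delta   <- control$adapt_delta
  if (!is.null(control$max_treedepth)) out$max_treedepth <- control$max_treedepth
  out
}

sampleArgs <- rstan_to_cmdstanr_args(iter = 2000, warmup = 1000, chains = 4,
                                     cores = 4, seed = 123, refresh = 100,
                                     control = list(adapt_delta = 0.9,
                                                    max_treedepth = 12))
sampleArgs$iter_sampling   # 1000
# fit <- do.call(mod$sample, c(list(data = dataList), sampleArgs))  # needs cmdstanr
```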

In addition, I let the function print Stan's diagnostic messages, and I include the computation of LOO (approximate leave-one-out cross-validation) for model comparison.
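LOO works on the pointwise log-likelihood of every observation at every posterior draw, usually saved as a `log_lik` variable in the generated quantities block (whether rw_hba.stan uses that name is an assumption). The {loo} package then applies Pareto-smoothed importance sampling to that draws × observations matrix. As a simpler illustration of what the matrix holds, the base-R sketch below computes the in-sample log pointwise predictive density (lppd) with a numerically stable log-sum-exp. This is not LOO itself, which additionally corrects for each observation having been used to fit the model; the function names are hypothetical.

```r
# Stable log(mean(exp(x))): subtract the maximum before exponentiating.
log_mean_exp <- function(x) {
  m <- max(x)
  m + log(mean(exp(x - m)))
}

# lppd from a (draws x observations) log-likelihood matrix:
# average the likelihood over draws, then sum the logs over observations.
lppd <- function(logLik) {
  sum(apply(logLik, 2, log_mean_exp))
}

# Toy example: 4 draws, 3 observations, every likelihood equal to 0.5
toy <- matrix(log(0.5), nrow = 4, ncol = 3)
lppd(toy)   # 3 * log(0.5), about -2.079
```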

Note: when test = TRUE, the wrapper function runs only 1 chain with 2 samples. This test mode is ideal for quickly debugging Stan models.
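Because the model still compiles and runs end to end, such a short run surfaces syntax errors and data-shape mismatches within seconds. A minimal sketch of how a test switch can set the sampler variables follows; the names mirror those in the `$sample()` call above, but the helper `sampler_settings`, the test-mode warmup length and the non-test values are assumptions, not the defaults in rw_run.R.

```r
# Hypothetical sketch of a test switch for the sampler settings.
# Non-test values are placeholders, not the defaults used in rw_run.R.
sampler_settings <- function(test = FALSE) {
  if (test) {
    # debugging: 1 chain, 2 post-warmup draws (warmup length is an assumption)
    list(nChains = 1, nCores = 1, nWarmup = 2, nIters = 4)
  } else {
    list(nChains = 4, nCores = 4, nWarmup = 1000, nIters = 2000)
  }
}

s <- sampler_settings(test = TRUE)
s$nIters - s$nWarmup   # 2, the value passed as iter_sampling
```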


For bug reports, please contact Lei Zhang (lei.zhang@univie.ac.at).

Thanks to Markdown Cheatsheet and shields.io.


LICENSE

This license (CC BY-NC 4.0) gives you the right to reuse and adapt this material for non-commercial purposes, as long as you note any changes you made and provide a link to the original source. Read here for more details.