grf-labs/policytree

Deprecate one vs all `multi_causal_forest` in favor of new GRF estimator

Closed this issue · 4 comments

Once GRF 2.0 has been released, policytree should use the new multivariate causal forest instead.

  • multi_causal_forest should still work but dispatch to GRF's multi_arm_causal_forest, with a warning.
  • add note to the MAPL replication code in experiments that it was run with policytree version 1.0 and grf 1.1.
  • update the double_robust_scores function (and remove old ones)
  • update README/doc references
  • bump minimum GRF version to 2.0 and bump policytree version to 1.1

Just adding a memory scaling note for future reference: with k arms, the one-vs-all approach scales linearly in k, O(k*c), where c is the memory footprint of one forest. grf's multi-arm forest (grf-labs/grf#748) scales worse, since we need to store sufficient statistics in each tree node to run a multivariate kernel-weighted regression: the worst-case storage is O(k^2*n), where n is the total number of nodes in the forest.
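To make the two growth rates concrete, here is a rough back-of-the-envelope sketch in Python. The function names are hypothetical and the constants are dropped, so the two results are not in the same units (c is a per-forest footprint, n is a node count); it only illustrates how each bound grows with k and is not a measurement of grf's actual memory use.

```python
def one_vs_all_storage(k, c):
    """Storage order for k separate one-vs-all forests, each with footprint c."""
    return k * c


def multi_arm_worst_case_storage(k, n):
    """Worst-case storage order for one multi-arm forest with n nodes in total.

    Assumes on the order of k^2 sufficient-statistic entries per node,
    matching the O(k^2 * n) bound quoted above.
    """
    return k * k * n


if __name__ == "__main__":
    # Doubling k doubles the first bound and quadruples the second.
    for k in (2, 4, 8):
        print(k, one_vs_all_storage(k, c=1000), multi_arm_worst_case_storage(k, n=1000))
```

The gap is only a concern for a large number of arms; for a handful of arms the k^2 factor stays small.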

I understand from this issue that multi_causal_forest will be deprecated soon. In the meantime, I had a quick question: do we know whether the sample.weights option in multi_causal_forest behaves as expected? Thanks for your time!

Hi @austindenteh, yes, but because of the "one vs all" approach the sample weights may interact differently for each treatment (https://grf-labs.github.io/policytree/articles/policytree.html#multi-causal-forest-treatment-estimates-and-baselines). It may be worth waiting for GRF 2.0 or trying out the "multi_arm" causal forest in the GRF development version instead: https://grf-labs.github.io/grf/reference/multi_arm_causal_forest.html

Here the sample weights have a crystal clear interpretation: for a multivariate treatment matrix W and scalar outcome Y, GRF's prediction for sample x is the kernel-weighted regression Y = c + tau W + eps, with kernel weights a(x). With sample weights w, we replace the kernel weights with a'(x) = w * a(x).
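As an illustration of that reweighting, here is a minimal pure-Python sketch of a weighted least-squares fit of Y = c + tau W + eps in which the kernel weights are multiplied by the sample weights. It is not grf's implementation: the function names are hypothetical, the kernel weights are supplied by hand rather than coming from a forest, and any centering or other preprocessing grf may do is left out.

```python
def solve_linear_system(A, b):
    """Solve A z = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    M = [row[:] + [b[i]] for i, row in enumerate(A)]  # augmented matrix
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(col + 1, n):
            f = M[r][col] / M[col][col]
            for j in range(col, n + 1):
                M[r][j] -= f * M[col][j]
    z = [0.0] * n
    for i in reversed(range(n)):
        s = sum(M[i][j] * z[j] for j in range(i + 1, n))
        z[i] = (M[i][n] - s) / M[i][i]
    return z


def kernel_weighted_regression(Y, W, kernel_weights, sample_weights=None):
    """Fit Y = c + tau W + eps by weighted least squares; return (c, tau)."""
    if sample_weights is None:
        sample_weights = [1.0] * len(Y)
    # Combined weights: a'(x) = w * a(x)
    weights = [w * a for w, a in zip(sample_weights, kernel_weights)]
    # Design rows [1, W_i1, ..., W_ik]: intercept plus one column per arm
    X = [[1.0] + list(row) for row in W]
    p = len(X[0])
    XtX = [[sum(wt * x[i] * x[j] for wt, x in zip(weights, X)) for j in range(p)]
           for i in range(p)]
    XtY = [sum(wt * x[i] * y for wt, x, y in zip(weights, X, Y)) for i in range(p)]
    coef = solve_linear_system(XtX, XtY)
    return coef[0], coef[1:]


if __name__ == "__main__":
    # Toy data: two control units (Y = 0) and two treated units (Y = 1 and Y = 3).
    Y = [0.0, 0.0, 1.0, 3.0]
    W = [[0.0], [0.0], [1.0], [1.0]]
    a = [0.25, 0.25, 0.25, 0.25]  # made-up kernel weights a(x)
    print(kernel_weighted_regression(Y, W, a))                        # tau close to 2.0
    print(kernel_weighted_regression(Y, W, a, [1.0, 1.0, 1.0, 3.0]))  # tau close to 2.5
```

In the toy data, upweighting the treated unit with Y = 3 moves the treated weighted mean from 2 to 2.5, and the estimated tau moves with it, which is exactly the effect of swapping a(x) for w * a(x).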

(just a heads up that this is a change from how sample weighting was done in GRF < 2.0, the issue describing this is here: grf-labs/grf#796)

Hi @erikcs, this is a very helpful response, and I will follow your suggestions! Thanks for your time.