blackjax-devs/blackjax

Remove `transform` from MCLMC, and place it in `run_inference_loop`

reubenharry opened this issue

Current behavior

The MCLMC algorithm has a `transform` parameter. Nothing about it is specific to MCLMC: it simply lets the user specify a projection of the possibly high-dimensional position space onto a lower-dimensional one. The projection is applied inside the kernel, so at each step the kernel returns an info object containing:

```python
MCLMCInfo(
    transformed_position=transform(position),
    logdensity=logdensity,
    energy_change=kinetic_change - logdensity + state.logdensity,
    kinetic_change=kinetic_change * (dim - 1),
)
```
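The current arrangement can be sketched as follows. This is a minimal illustration, not blackjax code: `ToyInfo`, `build_toy_kernel`, and `run` are hypothetical names, and the Gaussian random-walk update is a placeholder standing in for the real MCLMC integrator step. What it mirrors is the design point above: the kernel itself takes `transform` and reports only the projected position in its info, so stacking the infos over many steps stores a low-dimensional array.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class ToyInfo(NamedTuple):
    # Hypothetical stand-in for MCLMCInfo: only the projected position is kept.
    transformed_position: jax.Array
    logdensity: jax.Array


def build_toy_kernel(logdensity_fn: Callable, transform: Callable, step_size: float = 0.1):
    """Hypothetical kernel with `transform` baked in, mirroring the current MCLMC design."""

    def kernel(rng_key, position):
        # Placeholder update: a Gaussian random-walk move, NOT the MCLMC integrator.
        new_position = position + step_size * jax.random.normal(rng_key, position.shape)
        info = ToyInfo(
            transformed_position=transform(new_position),
            logdensity=logdensity_fn(new_position),
        )
        return new_position, info

    return kernel


def run(rng_key, kernel, initial_position, num_steps):
    def one_step(position, key):
        new_position, info = kernel(key, position)
        # The full position is only carried; `info` is what gets stacked per step.
        return new_position, info

    keys = jax.random.split(rng_key, num_steps)
    return jax.lax.scan(one_step, initial_position, keys)
```

For example, with a 1000-dimensional standard Gaussian and `transform=lambda x: x[:2]`, running 500 steps leaves `infos.transformed_position` with shape `(500, 2)` rather than `(500, 1000)`.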

Desired behavior

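This was done to save memory: only the projected position, not the full one, has to be kept for every step. But I don't think there's any need for MCLMC to be special in this way. Instead, `run_inference_loop` should offer equivalent functionality (applying a user-supplied transform at each step of the loop), and the `transform` parameter should be removed from MCLMC.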
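A minimal sketch of the proposed alternative, assuming a kernel of the form `kernel(rng_key, state) -> (state, info)` and a loop built on `jax.lax.scan`. The names `run_inference_loop_with_transform` and `random_walk_kernel` are hypothetical and are not the blackjax API; the random-walk update is again only a placeholder. The kernel knows nothing about `transform`; the loop applies it inside the scan body, so only the projected values are stacked across steps.

```python
from typing import Callable

import jax
import jax.numpy as jnp


def run_inference_loop_with_transform(
    rng_key, kernel, initial_state, num_steps, transform: Callable = lambda x: x
):
    """Hypothetical loop: `kernel(key, state) -> (state, info)` is transform-agnostic."""

    def one_step(state, key):
        new_state, info = kernel(key, state)
        # Project inside the scan body, so only the transformed output is stacked.
        return new_state, (transform(new_state), info)

    keys = jax.random.split(rng_key, num_steps)
    final_state, (transformed, infos) = jax.lax.scan(one_step, initial_state, keys)
    return final_state, transformed, infos


def random_walk_kernel(rng_key, position, step_size=0.1):
    # Placeholder update, NOT MCLMC: a Gaussian random-walk move with an empty info.
    new_position = position + step_size * jax.random.normal(rng_key, position.shape)
    return new_position, ()
```

Because the projection lives in the loop rather than in the kernel, every sampler gets the same memory saving without a kernel-specific parameter, which is the point of the proposal.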