facebook/Ax

[GENERAL SUPPORT]: Improving run time of BO iterations

Opened this issue · 6 comments

Question

Hi, I'm trying to use BO on a problem where the sampling time (the time it takes to evaluate the objective function at a specific arm) is relatively short. In this case the BO iteration time can become the computational bottleneck, and I'd like to use Ax's framework without some of the additional computations it performs that aren't strictly necessary for the BO.

More specifically, I've seen that a considerable part of the run time per iteration is spent in get_fit_and_std_quality_and_generalization_dict (https://github.com/facebook/Ax/blob/main/ax/modelbridge/cross_validation.py#L409), which is called by ModelSpec.gen (https://github.com/facebook/Ax/blob/main/ax/modelbridge/model_spec.py#L239).
I'd like to skip the cross-validation computation on each iteration since I'm not using it anyway. I looked for a flag in ModelBridge (or anywhere else, for that matter) that would let me skip it, but wasn't able to find one. Is there a way to do that, or should I approach this in a different manner (like inheriting from ModelSpec and overriding this method)?

Additionally, I've seen that TorchModelBridge._gen also computes best_point on each iteration (https://github.com/facebook/Ax/blob/main/ax/modelbridge/torch.py#L729), which I would also like to skip, but I'm not sure there's a simple flag that allows me to do so. (This computation is very fast when using the best in-sample point, so it's of less importance to me; however, I would sometimes like to use a custom TorchModelBridge that computes best_point by optimizing the posterior mean, and I'd still want a flag that lets me skip that computation in _gen.)

I've also seen that Experiment.fetch_data (https://github.com/facebook/Ax/blob/main/ax/core/experiment.py#L575) takes quite a bit of time on each iteration, but I wasn't able to understand what it really does and what makes it computationally expensive.

Below is profiler output for a single BO iteration using Models.BOTORCH_MODULAR with 250 samples:
[profiler screenshot]

Thanks!



Hi @RoeyYadgar. Thanks for reporting this. In this case get_fit_and_std_quality_and_generalization_dict does seem to add significant overhead without directly contributing to candidate generation (it's used for diagnostics and reporting purposes). This method uses LOOCV under the hood, which scales super-linearly in the number of observations, so it becomes a much larger overhead with 250 trials than we typically face. I plan to do broad profiling of Ax & BoTorch candidate generation in H1, and I'll add this to the list of things to address as part of that project.
In the short term, you can mock out this method to avoid the overhead it introduces.

from unittest import mock

with mock.patch("ax.modelbridge.model_spec.get_fit_and_std_quality_and_generalization_dict", return_value={}):
    ...
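For example, in an ask-tell loop the patch can wrap candidate generation. A minimal sketch, assuming ax_client is an already-configured AxClient (not from the original thread):

from unittest import mock

# ax_client is assumed to be an AxClient whose experiment was already
# set up via ax_client.create_experiment(...).
with mock.patch(
    "ax.modelbridge.model_spec.get_fit_and_std_quality_and_generalization_dict",
    return_value={},
):
    parameters, trial_index = ax_client.get_next_trial()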

Hi @saitcakmak, thanks a lot! Mocking out this method is really helpful!

I also wanted to ask about the data fetching: it seems a fair amount of time is spent in Metric._wrap_trial_data_multi (which constructs a new BaseData object), and it's performed for each existing observation, so it adds up. Is there something I can do to improve on that?

What I ended up doing for now was to implement a custom Experiment and override the fetch_data method. It works, but it's not a very clean solution (I also use the JSON storage feature of Ax, so I had to define an encoder & decoder registry for it). Is there something cleaner I can do?
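For reference, a minimal sketch of that kind of workaround, assuming all data is attached directly to the experiment (LookupOnlyExperiment is a hypothetical name, not from the thread):

from ax.core.experiment import Experiment

class LookupOnlyExperiment(Experiment):
    # Hypothetical subclass: skip per-trial/per-metric fetching and
    # return only the data that is already attached to the experiment.
    def fetch_data(self, metrics=None, **kwargs):
        return self.lookup_data()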

I am not that familiar with the internals of Experiment.fetch_data, but I'll look into it.

cc @mpolson64 - data fetching (rather than lookup) seems to be adding a significant overhead

Wow, the difference between experiment.fetch_data and experiment.lookup_data is huge. With 1000 trials, lookup takes only 0.015 seconds while fetch takes 14.185 seconds.
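A minimal way to reproduce this kind of comparison, assuming experiment is an Ax Experiment with its data already attached:

import time

# lookup_data returns the attached data directly, while fetch_data
# re-checks every trial & metric for new data before returning.
for fn in (experiment.lookup_data, experiment.fetch_data):
    start = time.monotonic()
    fn()
    print(f"{fn.__name__}: {time.monotonic() - start:.3f}s")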

If the data is pre-attached to the experiment (i.e., there is no Metric class that does some querying to fetch the metrics), these two are functionally identical. However, fetch loops through all trials & metrics to check whether there is any new data to retrieve, which takes a lot longer than just looking up the data that's already attached.

@RoeyYadgar, I am guessing you're using an ask-tell setup here. If so, you can just replace fetch_data with lookup_data.
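That is, something along these lines (a sketch, assuming the same experiment object):

data = experiment.lookup_data()  # instead of: data = experiment.fetch_data()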

Thanks for pointing me to that! I wasn't aware of the lookup_data method.

I found a few places where things could be improved. Experiment.fetch_data gets 10x faster (with 1000 trials & 2 metrics) after #3217