Constructing potential energy functions for many systems on same device
Context
We would like to define a Jax function referencing the potential energy functions of many systems at once. (E.g. to define a differentiable loss function in terms of a batch of molecules.)
Problem
The way the Jax wrappers are currently implemented requires constructing a separate unbound impl for each system, which means (1) some start-up time to instantiate all the impls, and (2) a limit on the number of systems, imposed by available GPU memory.
Reproduction
The following loop, which constructs a potential energy function for each molecule in FreeSolv, takes a couple of seconds per iteration and then crashes when the GPU runs out of memory.
import timemachine

print(timemachine.__version__)
# toggle-energy-minimization branch @ march 29, 2022
# https://github.com/proteneer/timemachine/tree/a2037e14ccefcdad2ac7465a139412893db27cf8
# (so that loop over mols doesn't have to call minimize_host_4d just to construct the potentials)

from timemachine.datasets import fetch_freesolv
from timemachine.md import enhanced
from timemachine.fe.functional import construct_differentiable_interface_fast
from timemachine.ff import Forcefield

ff = Forcefield.load_from_file("smirnoff_1_1_0_ccc.py")
mols = fetch_freesolv()

def prepare_energy_fxn(mol):
    ubps, params, masses, coords, box = enhanced.get_solvent_phase_system(mol, ff, minimize_energy=False)
    U_fxn = construct_differentiable_interface_fast(ubps, params)
    return U_fxn, params

# crashes after a few dozen iterations / several minutes
energy_fxns = []
for mol in mols:
    energy_fxns.append(prepare_energy_fxn(mol))
    print(len(energy_fxns))

# ...
# def loss_fxn(ff_params):
#     # ...
#     # something that requires U_fxn(...) for (U_fxn, _) in energy_fxns
#     # ...
#     return loss
# ...
# _ = grad(loss_fxn)(ff_params)
# ...
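For concreteness, the kind of loss hinted at above might look roughly like the following sketch. This is an illustration rather than the issue's actual code: it assumes each U_fxn can be called as U_fxn(coords, params, box, lam) and that the coordinates and box for each system are also kept, and it omits the step that maps force-field parameters (ff_params) to each system's params.

```python
# Hypothetical sketch of the intended batch loss (not part of the original reproduction).
# Assumes each U_fxn can be called as U_fxn(coords, params, box, lam), and that coords and
# box for each system were also kept around when building energy_fxns.
import jax
import jax.numpy as jnp

def loss_fxn(all_params, systems, lam=0.0):
    """systems: list of (U_fxn, coords, box); all_params: list of per-system params."""
    # placeholder loss: sum of solvent-phase potential energies
    # (a real loss would compare quantities derived from these energies to reference data)
    energies = [U_fxn(coords, params, box, lam)
                for (U_fxn, coords, box), params in zip(systems, all_params)]
    return jnp.sum(jnp.array(energies))

# gradients w.r.t. the per-system parameters; flowing these back to ff_params would
# additionally require re-parameterizing each system's params from the force field
# dl_dparams = jax.grad(loss_fxn)(all_params, systems)
```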
(And a slightly loggier version of this loop that also queries GPU memory each iteration)
import subprocess as sp
import os

def get_gpu_memory(device=0):
    """adapted from https://stackoverflow.com/a/59571639"""
    command = "nvidia-smi --query-gpu=memory.free --format=csv"
    memory_free_info = sp.check_output(command.split()).decode('ascii').split('\n')[:-1][1:]
    memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)]
    return memory_free_values[device]

import numpy as np
import timemachine

print(timemachine.__version__)
# toggle-energy-minimization branch @ march 29, 2022
# https://github.com/proteneer/timemachine/tree/a2037e14ccefcdad2ac7465a139412893db27cf8
# (so that loop over mols doesn't have to call minimize_host_4d just to construct the potentials)

from tqdm import tqdm
import time

from timemachine.datasets import fetch_freesolv
from timemachine.md import enhanced
from timemachine.fe.functional import construct_differentiable_interface_fast
from timemachine.ff import Forcefield
from timemachine.lib.potentials import NonbondedInteractionGroup

ff = Forcefield.load_from_file("smirnoff_1_1_0_ccc.py")
mols = fetch_freesolv()

def nb_ig_from_nb(nonbonded_ubp):
    lambda_plane_idxs = nonbonded_ubp.get_lambda_plane_idxs()
    lambda_offset_idxs = nonbonded_ubp.get_lambda_offset_idxs()
    beta = nonbonded_ubp.get_beta()
    cutoff = nonbonded_ubp.get_cutoff()
    ligand_idxs = np.array(np.where(lambda_offset_idxs != 0)[0], dtype=np.int32)

    # switched from Nonbonded to NonbondedInteractionGroup in hopes of reducing memory consumption
    nb_ig = NonbondedInteractionGroup(ligand_idxs, lambda_plane_idxs, lambda_offset_idxs, beta, cutoff)
    return nb_ig

def prepare_energy_fxn(mol, use_interaction_group=True):
    ubps, params, masses, coords, box = enhanced.get_solvent_phase_system(mol, ff, minimize_energy=False)
    n_atoms = len(coords)

    if use_interaction_group:
        ubps_prime = ubps[:-1] + [nb_ig_from_nb(ubps[-1])]
        U_fxn = construct_differentiable_interface_fast(ubps_prime, params)
    else:
        U_fxn = construct_differentiable_interface_fast(ubps, params)

    return (U_fxn, params, n_atoms)

# hmm, still crashes after several minutes
energy_fxns = []
n_atoms_traj = [0]
device_idx = 1
free_memory_traj = [get_gpu_memory(device_idx)]
use_interaction_group = False

for mol in mols:
    U_fxn, params, n_atoms = prepare_energy_fxn(mol, use_interaction_group=use_interaction_group)
    energy_fxns.append((U_fxn, params))
    n_atoms_traj.append(n_atoms)

    # wait a bit before querying nvidia-smi
    time.sleep(0.5)
    free_memory_traj.append(get_gpu_memory(device_idx))

np.savez(f'memory_log_ig={use_interaction_group}', free_memory_traj=free_memory_traj, n_atoms_traj=n_atoms_traj)
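If the loop is stopped before it exhausts memory (or the np.savez call is moved inside the loop), the saved log gives a per-system memory estimate. A minimal sketch of that follow-up, assuming nvidia-smi reports free memory in MiB:

```python
# load the log written above and estimate memory consumed per newly constructed system
import numpy as np

log = np.load("memory_log_ig=False.npz")
free_memory_traj = log["free_memory_traj"]   # MiB free after each system is constructed
n_atoms_traj = log["n_atoms_traj"]

per_system_mib = -np.diff(free_memory_traj)  # drop in free memory per iteration
print("median MiB per system:", np.median(per_system_mib))
print("median atoms per system:", np.median(n_atoms_traj[1:]))
```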
Notes
- In this loop the systems contain about 2500 atoms each.
- @proteneer suggested checking whether the memory consumption per system is reduced by using the NonbondedInteractionGroup potential instead of the default Nonbonded potential -- in both cases the memory consumption appears to be ~64 MB per system.
- The crash occurs when nvidia-smi reports ~1 GB of memory remaining.
Possible solutions
- Refactor the Jax wrappers so that multiple systems can use the same underlying unbound_impl? (May address both the startup cost and the memory limit)
- Reduce the memory consumption of each (neighborlist?) impl? (May be more invasive, and may address only the memory limit)
- Use separate GPUs for separate systems?
- Use Jax reference implementation in these cases?
- ...
Refactor the Jax wrappers so that multiple systems can use the same underlying unbound_impl?
Actually, this route would probably require a deeper change than just the Jax wrappers (setters like potential.set_idxs(...) etc. are available for the potentials, but I think idxs etc. are constant for the lifetime of the impls created by impl = potential.unbound_impl(precision) ...)
Reduce the memory consumption of each (neighborlist?) impl?
Not sure how much room there is for reduction here -- memory usage is close to what you would expect for 2500 * 3 float64s ... Oof, off by a factor of a thousand: should expect 60 kilobytes not 60 megabytes...
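A quick back-of-the-envelope check of that correction, assuming ~2500 atoms per solvated system and float64 coordinates:

```python
n_atoms = 2500
coord_bytes = n_atoms * 3 * 8        # x, y, z per atom at 8 bytes each -> 60,000 bytes (~60 kB)
observed_bytes = 64 * 2**20          # ~64 MiB actually consumed per system
print(observed_bytes / coord_bytes)  # ~1100, i.e. off by roughly a factor of a thousand
```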
From looking into this with @proteneer and @jkausrelay: the source of the 64 MB allocations is timemachine/timemachine/cpp/src/nonbonded_all_pairs.cu, lines 120 to 137 (at 136da6b).
Short-term solution: Replace 256 with NUM_BINS = 96 (or 128), which should have no impact on correctness and a modest impact on efficiency, and will reduce the allocation size from ~64 MB to ~3 MB (or ~8 MB) per system.
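Those figures are consistent with a dense NUM_BINS^3 table of 4-byte indices, which is the assumed layout of the Hilbert-curve bin-to-index mapping allocated there:

```python
# size of a dense NUM_BINS**3 table of 4-byte indices (assumed layout of the
# Hilbert-curve bin -> index mapping allocated in nonbonded_all_pairs.cu)
for num_bins in (256, 128, 96):
    print(f"NUM_BINS={num_bins}: {num_bins**3 * 4 / 2**20:.1f} MiB")
# NUM_BINS=256: 64.0 MiB; NUM_BINS=128: 8.0 MiB; NUM_BINS=96: 3.4 MiB
```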
Longer-term solution: Extract this Hilbert index into a separate object that can be reused by multiple nonbonded impls, rather than creating a duplicate for each nonbonded impl.
Side quest: Possibly re-run a benchmarking script over a grid of NUM_BINS settings, to measure the performance impact of this parameter.
(Side note: This index is also duplicated between nonbonded_all_pairs.cu and nonbonded_interaction_group.cu, so factoring it out of both files could be related to #639.)
https://github.com/proteneer/timemachine/blob/master/timemachine/cpp/src/kernels/k_nonbonded.cu#L3 would also need to be updated; 256 is hardcoded there.
good call!