proteneer/timemachine

Constructing potential energy functions for many systems on the same device


Context
We would like to define a Jax function referencing the potential energy functions of many systems at once. (E.g. to define a differentiable loss function in terms of a batch of molecules.)

Problem
The way the Jax wrappers are currently implemented requires constructing a separate unbound impl for each system, which means (1) some start-up time is needed to instantiate all the impls, and (2) the number of systems is limited by available GPU memory.

Reproduction
The following loop, which constructs a potential energy function for each molecule in FreeSolv, takes a couple of seconds per iteration and then crashes when the GPU runs out of memory.

import timemachine
print(timemachine.__version__)
# toggle-energy-minimization branch @ march 29, 2022
# https://github.com/proteneer/timemachine/tree/a2037e14ccefcdad2ac7465a139412893db27cf8
# (so that loop over mols doesn't have to call minimize_host_4d just to construct the potentials)

from timemachine.datasets import fetch_freesolv
from timemachine.md import enhanced
from timemachine.fe.functional import construct_differentiable_interface_fast
from timemachine.ff import Forcefield

ff = Forcefield.load_from_file("smirnoff_1_1_0_ccc.py")
mols = fetch_freesolv()


def prepare_energy_fxn(mol):
    ubps, params, masses, coords, box = enhanced.get_solvent_phase_system(mol, ff, minimize_energy=False)
    U_fxn = construct_differentiable_interface_fast(ubps, params)
    
    return U_fxn, params


# crashes after a few dozen iterations / several minutes
energy_fxns = []
for mol in mols:
    energy_fxns.append(prepare_energy_fxn(mol))
    print(len(energy_fxns))

# ...
# def loss_fxn(ff_params):
#    # ...
#    # something that requires U_fxn(...) for (U_fxn, _) in energy_fxns
#    # ...
#    return loss
# ...
# _ = grad(loss_fxn)(ff_params)
# ...
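For concreteness, a minimal sketch of what such a loss might look like. This is hypothetical: coords_list, boxes, and target_energies are placeholder inputs, and it assumes the differentiable interface exposes a call like U_fxn(coords, params, box, lam).

import jax.numpy as jnp
from jax import grad

def loss_fxn(ff_params):
    # placeholder: in a real loss, the per-system params would be re-derived
    # from ff_params so that gradients flow back to the forcefield
    losses = []
    for (U_fxn, params), coords, box, target in zip(
        energy_fxns, coords_list, boxes, target_energies
    ):
        U = U_fxn(coords, params, box, 0.0)  # assumed signature
        losses.append((U - target) ** 2)
    return jnp.mean(jnp.array(losses))

_ = grad(loss_fxn)(ff_params)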

(And a slightly more verbose version of this loop that also queries GPU memory on each iteration:)

import subprocess as sp

def get_gpu_memory(device=0):
    """adapted from https://stackoverflow.com/a/59571639"""
    command = "nvidia-smi --query-gpu=memory.free --format=csv"
    memory_free_info = sp.check_output(command.split()).decode('ascii').split('\n')[:-1][1:]
    memory_free_values = [int(x.split()[0]) for x in memory_free_info]
    return memory_free_values[device]

import numpy as np

import timemachine
print(timemachine.__version__)
# toggle-energy-minimization branch @ march 29, 2022
# https://github.com/proteneer/timemachine/tree/a2037e14ccefcdad2ac7465a139412893db27cf8
# (so that loop over mols doesn't have to call minimize_host_4d just to construct the potentials)

from tqdm import tqdm

import time
from timemachine.datasets import fetch_freesolv
from timemachine.md import enhanced
from timemachine.fe.functional import construct_differentiable_interface_fast
from timemachine.ff import Forcefield
from timemachine.lib.potentials import NonbondedInteractionGroup

ff = Forcefield.load_from_file("smirnoff_1_1_0_ccc.py")
mols = fetch_freesolv()


def nb_ig_from_nb(nonbonded_ubp):
    lambda_plane_idxs = nonbonded_ubp.get_lambda_plane_idxs()
    lambda_offset_idxs = nonbonded_ubp.get_lambda_offset_idxs()
    beta = nonbonded_ubp.get_beta()
    cutoff = nonbonded_ubp.get_cutoff()
    
    ligand_idxs = np.array(np.where(lambda_offset_idxs != 0)[0], dtype=np.int32)

    # switched from Nonbonded to NonbondedInteractionGroup in hopes of reducing memory consumption
    nb_ig = NonbondedInteractionGroup(ligand_idxs, lambda_plane_idxs, lambda_offset_idxs, beta, cutoff)
    
    return nb_ig


def prepare_energy_fxn(mol, use_interaction_group=True):
    ubps, params, masses, coords, box = enhanced.get_solvent_phase_system(mol, ff, minimize_energy=False)
    n_atoms = len(coords)
    
    if use_interaction_group:
        ubps_prime = ubps[:-1] + [nb_ig_from_nb(ubps[-1])]
        U_fxn = construct_differentiable_interface_fast(ubps_prime, params)
    else:
        U_fxn = construct_differentiable_interface_fast(ubps, params)
    
    return (U_fxn, params, n_atoms)


# hmm, still crashes after several minutes
energy_fxns = []
n_atoms_traj = [0]

device_idx = 1

free_memory_traj = [get_gpu_memory(device_idx)]
use_interaction_group = False

for mol in mols:
    U_fxn, params, n_atoms = prepare_energy_fxn(mol, use_interaction_group=use_interaction_group)
    energy_fxns.append((U_fxn, params))
    n_atoms_traj.append(n_atoms)
    
    # wait a bit before querying nvidia-smi
    time.sleep(0.5)
    free_memory_traj.append(get_gpu_memory(device_idx))
    
    np.savez(f'memory_log_ig={use_interaction_group}', free_memory_traj=free_memory_traj, n_atoms_traj=n_atoms_traj)

Notes

  • In this loop the systems contain about 2500 atoms each.
  • @proteneer suggested checking whether the memory consumption per system is reduced by using the NonbondedInteractionGroup potential instead of the default Nonbonded potential -- in both cases the memory consumption appears to be ~64MB per system.
    [plot: free GPU memory vs. number of instantiated systems, Nonbonded vs. NonbondedInteractionGroup -- ~64MB consumed per system in both cases]
  • Crash occurs in this loop when nvidia-smi reports ~1GB of memory remaining (see the back-of-envelope estimate below)
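Back-of-envelope, using the observed ~64MB per system: the number of systems that fit should be roughly the initially free GPU memory divided by 64MB. A quick check, reusing get_gpu_memory and device_idx from the logging script above:

# rough capacity estimate from the observed ~64MB/system
initial_free_mb = get_gpu_memory(device_idx)
print(f"expect to hit OOM after roughly {initial_free_mb // 64} systems")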

Possible solutions

  • Refactor the Jax wrappers so that multiple systems can use the same underlying unbound_impl? (May address both the startup cost and the memory limit)
  • Reduce the memory consumption of each (neighborlist?) impl? (May be more invasive, may address only the memory limit)
  • Use separate GPUs for separate systems?
  • Use Jax reference implementation in these cases?
  • ...

Refactor the Jax wrappers so that multiple systems can use the same underlying unbound_impl?

Actually, this route would probably require a deeper change than just the Jax wrappers: setters like potential.set_idxs(...) are available on the potentials, but I think idxs etc. are fixed for the lifetime of an impl once impl = potential.unbound_impl(precision) has been constructed...

Reduce the memory consumption of each (neighborlist?) impl?

Not sure how much room there is for reduction here -- memory usage is close to what you would expect for 2500 * 3 float64s ... Oof, off by a factor of a thousand: 2500 atoms * 3 coordinates * 8 bytes is about 60 kilobytes, not the ~60 megabytes observed...

From looking into this with @proteneer and @jkausrelay: the source of the 64MB allocations is

// initialize hilbert curve
std::vector<unsigned int> bin_to_idx(256 * 256 * 256);
for (int i = 0; i < 256; i++) {
    for (int j = 0; j < 256; j++) {
        for (int k = 0; k < 256; k++) {
            bitmask_t hilbert_coords[3];
            hilbert_coords[0] = i;
            hilbert_coords[1] = j;
            hilbert_coords[2] = k;
            unsigned int bin = static_cast<unsigned int>(hilbert_c2i(3, 8, hilbert_coords));
            bin_to_idx[i * 256 * 256 + j * 256 + k] = bin;
        }
    }
}
gpuErrchk(cudaMalloc(&d_bin_to_idx_, 256 * 256 * 256 * sizeof(*d_bin_to_idx_)));

Short term solution: Replace 256 with NUM_BINS = 96 (or 128). This should have no impact on correctness and only a modest impact on efficiency, and it reduces the allocation from ~64MB to ~3MB (or ~8MB) per system.
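A quick check of those allocation sizes: the bin_to_idx table holds NUM_BINS**3 unsigned ints at 4 bytes each.

for num_bins in [96, 128, 256]:
    size_mib = num_bins**3 * 4 / 2**20
    print(f"NUM_BINS={num_bins}: {size_mib:.1f} MiB")
# NUM_BINS=96: 3.4 MiB
# NUM_BINS=128: 8.0 MiB
# NUM_BINS=256: 64.0 MiB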

Longer term solution: Extract this Hilbert index into a separate object that can be reused by multiple nonbonded impls, rather than creating a duplicate for each nonbonded impl.
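A minimal sketch of the reuse pattern, in Python for illustration only (the real change would live in the CUDA code, and the table contents here are a placeholder for the hilbert_c2i loop above):

import functools

import numpy as np

@functools.lru_cache(maxsize=None)
def get_bin_to_idx(num_bins):
    """Build the bin -> index table once per num_bins and share it.

    Placeholder body: the real table stores hilbert_c2i(3, 8, (i, j, k));
    the point here is only the build-once, reuse-everywhere pattern.
    """
    bin_to_idx = np.empty(num_bins**3, dtype=np.uint32)
    for flat in range(num_bins**3):
        bin_to_idx[flat] = flat  # placeholder for the Hilbert index
    return bin_to_idx

# all impls share one table per num_bins instead of each allocating its own
table_a = get_bin_to_idx(128)
table_b = get_bin_to_idx(128)
assert table_a is table_b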

Side quest: Possibly re-run a benchmarking script over a grid of NUM_BINS settings, to measure the performance impact of this parameter.

(Side note: This index is also duplicated between nonbonded_all_pairs.cu and nonbonded_interaction_group.cu, so factoring this out of both files could be related to #639 .)

good call!

Was resolved by #692.

(And the need to instantiate batches of GPU potentials for reweighting is avoided anyway when the linear basis function trick (#685, #931) is applicable.)