Multi-GPU simulation and training of a dynamical system in brainpy
Hi, Chaoming:
I am trying to simulate and train a dynamical system (a self-customized RNN based on brainpy, https://github.com/Dr-Chen-Xiaoyu/DecoModel) with very large dimensions and many time steps. The memory usage exceeds a single GPU device.
I believe this could be solved by running brainpy on multiple GPU devices with its own sharding method, similar to jax's sharding or pytorch's torch.nn.DataParallel. A simplified RNN training example is provided below; to reproduce the problem, increase the dimension of the RNN (maybe >1000) as well as the input/output tensors (maybe >1000^3). Maybe you could modify this code with brainpy's sharding and make it part of brainpy's tutorials if this is a general demand from users.
best,
Xiaoyu
The example code:
# %%
import os,jax
import numpy as np
import matplotlib.pyplot as plt
import brainpy as bp
import brainpy.math as bm
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0,1" # specify which GPU(s) to be used
bm.disable_gpu_memory_preallocation()
bm.set_platform('gpu')
bm.set_mode(bm.training_mode)
print('bp version:', bp.__version__)
print(jax.local_devices())
#bp version: 2.4.6.post5
#[cuda(id=0), cuda(id=1)]
# %%
class RNN(bp.DynamicalSystemNS):
    def __init__(self, num_in, num_hid, num_out, batch_size=1):
        super(RNN, self).__init__()
        bp.check.is_subclass(self.mode, (bm.TrainingMode, bm.BatchingMode))

        # define parameters
        self.num_in = num_in
        self.num_hid = num_hid
        self.num_out = num_out

        # define variables
        self.state = bm.Variable(bm.zeros((batch_size, num_hid)), batch_axis=0)

        # define weights
        self.win = bm.TrainVar(bm.random.normal(0., 1., size=(num_in, num_hid)))
        self.wrec = bm.TrainVar(bm.random.normal(0., 1., size=(num_hid, num_hid)))
        self.wout = bm.TrainVar(bm.random.normal(0., 1., size=(num_hid, num_out)))

    def reset_state(self, batch_size):  # this function defines how to reset the model states
        self.state.value = bm.zeros((batch_size, self.num_hid))

    def update(self, x):  # this function defines how the model updates its state and produces its output
        self.state.value = bm.tanh(bm.matmul(x, self.win) + bm.matmul(self.state, self.wrec))
        return bm.matmul(self.state, self.wout)
# initialize model
bm.random.seed(123)
dim_in = 1
dim_hid = 10
dim_out = 1
batch_size = 1
model = RNN(dim_in, dim_hid, dim_out, batch_size)
# %%
# generate some data
Nsample = 500
X_train = bm.random.normal(0., 1., size=(batch_size, Nsample, dim_in))  # (Batch, Time, dim)
Y_train = bm.random.normal(10., 1., size=(batch_size, Nsample, dim_out))
def plot_model_predict(model, X_train, Y_train):
    runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
    Y_model = runner.run(inputs=X_train)
    plt.plot(X_train[0, :, :])
    plt.plot(Y_train[0, :, :])
    plt.plot(Y_model[0, :, :])
    plt.show()

plot_model_predict(model, X_train, Y_train)
# %%
# training
def loss_fun(inputs, targets):
    runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
    predicts = runner.predict(inputs)
    loss = bp.losses.mean_squared_error(predicts, targets)
    return loss
grad_fun = bm.grad(loss_fun, grad_vars=model.train_vars().unique(), return_value=True)
opt = bp.optim.Adam(lr=1e-1, train_vars=model.train_vars().unique())

@bm.jit
def train(xs, ys):
    grads, loss = grad_fun(xs, ys)
    opt.update(grads)
    return loss

losses = []
for _ in range(1000):
    losses.append(train(X_train, Y_train))

plt.plot(losses)
plt.show()

plot_model_predict(model, X_train, Y_train)
I think I might have found the way to shard a bm.array, based on JAX's tutorial https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html :
# %%
import jax
import jax.numpy as jnp
import os
import numpy as np
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0,1" # specify which GPU(s) to be used
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
import brainpy as bp
import brainpy.math as bm
bm.disable_gpu_memory_preallocation()
bm.set_platform('gpu')
print('bp version:', bp.__version__)
from jax.experimental import mesh_utils
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PositionalSharding
from jax.sharding import PartitionSpec as P
# %%
def get_sharding_details(sharded_data):
    # We can get detailed information for each shard
    print("=" * 75)
    for i, shard in enumerate(sharded_data.global_shards):
        print(f"Shard no: {i:>5}")
        print(f"Device: {str(shard.device):>32}")
        print(f"Data shape: {str(shard.data.shape):>8}")
        print(f"Data slices: {str(shard.index):>22}")
        print("=" * 75)
# %%
devices = mesh_utils.create_device_mesh((len(jax.local_devices()),))
print(f"Device Array: {devices}")
# Create a mesh from the device array
mesh = Mesh(devices, axis_names=("ax",))
# Define sharding with a partition spec
sharding = NamedSharding(mesh, P("ax"))
print(mesh)
# %%
a = jnp.ones((1000,1000,3))
get_sharding_details(a)
print("\nafter sharding:\n")
# Shard the data
b = jax.device_put(a, sharding)
get_sharding_details(b)
# %%
c = bm.ones((1000,1000,3))
get_sharding_details(c.value)
print("\nafter sharding:\n")
# Shard the data
d = bm.sharding.partition_by_sharding(c, sharding)
get_sharding_details(d.value)
Maybe we could just shard the input/output bm.array tensors along the batch axis, and then let the computation run on multiple GPUs automatically? (a rough sketch of this idea follows the printout below)
Just some thoughts 😊
The printed output looks like this for the before- and after-sharding arrays:
===========================================================================
Shard no: 0
Device: cuda:0
Data shape: (1000, 1000, 3)
Data slices: (slice(None, None, None), slice(None, None, None), slice(None, None, None))
===========================================================================
after sharding:
===========================================================================
Shard no: 0
Device: cuda:0
Data shape: (500, 1000, 3)
Data slices: (slice(0, 500, None), slice(None, None, None), slice(None, None, None))
===========================================================================
Shard no: 1
Device: cuda:1
Data shape: (500, 1000, 3)
Data slices: (slice(500, 1000, None), slice(None, None, None), slice(None, None, None))
===========================================================================
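As a rough illustration of the batch-axis idea above (this is plain JAX, not a brainpy API; it assumes the batch dimension is divisible by the number of devices, and the axis name 'batch' is my own choice):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# one-dimensional device mesh over the local devices
mesh = Mesh(np.asarray(jax.devices()), axis_names=('batch',))
# shard only the leading (Batch) axis; Time and dim stay replicated
batch_sharding = NamedSharding(mesh, P('batch', None, None))

X = jnp.ones((4, 500, 1))  # toy (Batch, Time, dim) input
X_sharded = jax.device_put(X, batch_sharding)  # each device holds a slice of the batch
print(X_sharded.sharding)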
Thanks for the question. Sorry for the slow response. I will check it later.
Hi, chaoming @chaoming0625
Maybe this issue is a bit hard, with too much engineering work required. 🫡
I just have an idea for a quick and cheap solution to this issue. As in #663, if any built-in or customized brainpy dynamical system class could be automatically transformed into a Flax RNN cell using bp.dnn.ToFlaxRNNCell(), then we could do multi-GPU parallel training directly in Flax (https://flax.readthedocs.io/en/latest/guides/parallel_training/index.html). 🤖
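For reference, the pattern in that Flax guide boils down to the usual JAX data-parallel step; the following is only a generic sketch with made-up shapes and a placeholder linear model, not brainpy or Flax API:

from functools import partial
import jax
import jax.numpy as jnp

def loss_fn(params, xs, ys):
    # placeholder forward pass; a real RNN cell would go here
    pred = xs @ params
    return jnp.mean((pred - ys) ** 2)

@partial(jax.pmap, axis_name='batch')
def train_step(params, xs, ys):
    loss, grads = jax.value_and_grad(loss_fn)(params, xs, ys)
    # average gradients across devices so every replica stays identical
    grads = jax.lax.pmean(grads, axis_name='batch')
    return params - 1e-3 * grads, loss

n_dev = jax.local_device_count()
params = jnp.zeros((4, 1))
params_repl = jnp.broadcast_to(params, (n_dev,) + params.shape)  # replicate params per device
xs = jnp.ones((n_dev, 8, 4))  # leading axis = device axis for pmap
ys = jnp.ones((n_dev, 8, 1))
params_repl, loss = train_step(params_repl, xs, ys)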
best,
Xiaoyu Chen
Yes, the idea is simple. I will give you the solution soon.
Here is my example of using multiple GPUs. I marked the key code with the comment [KEY].
import os
import jax
import brainpy as bp
import brainpy.math as bm
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" # specify which GPU(s) to be used
bm.disable_gpu_memory_preallocation()
bm.set_platform('gpu')
bm.set_mode(bm.training_mode)
print('bp version:', bp.__version__)
print(jax.local_devices())
# bp version: 2.4.6.post5
# [cuda(id=0), cuda(id=1)]
# %%
class RNN(bp.DynamicalSystemNS):
    def __init__(self, num_in, num_hid, num_out, batch_size=1):
        super(RNN, self).__init__()
        bp.check.is_subclass(self.mode, (bm.TrainingMode, bm.BatchingMode))

        # define parameters
        self.num_in = num_in
        self.num_hid = num_hid
        self.num_out = num_out

        # define variables [KEY]
        self.state = bp.init.variable(bm.zeros, num_hid, batch_size, axis_names=['hidden'])
        # self.state = bm.Variable(bm.zeros((batch_size, num_hid)), batch_axis=0)

        # define weights [KEY]
        self.win = bm.TrainVar(bp.init.variable_(bp.init.Normal(), (num_in, num_hid), axis_names=[None, 'hidden']))
        self.wrec = bm.TrainVar(bp.init.variable_(bp.init.Normal(), (num_hid, num_hid), axis_names=[None, 'hidden']))
        self.wout = bm.TrainVar(bp.init.variable_(bp.init.Normal(), (num_hid, num_out), axis_names=['hidden', None]))
        # self.win = bm.TrainVar(bm.random.normal(0., 1., size=(num_in, num_hid)))
        # self.wrec = bm.TrainVar(bm.random.normal(0., 1., size=(num_hid, num_hid)))
        # self.wout = bm.TrainVar(bm.random.normal(0., 1., size=(num_hid, num_out)))

    def reset_state(self, batch_size):  # this function defines how to reset the model states
        self.state.value = bp.init.variable_(bm.zeros, (self.num_hid,), batch_size, axis_names=['hidden'])

    def update(self, x):  # this function defines how the model updates its state and produces its output
        self.state.value = bm.tanh(bm.matmul(x, self.win) + bm.matmul(self.state, self.wrec))
        return bm.matmul(self.state, self.wout)
with bm.sharding.device_mesh(jax.devices(), ['hidden']):  # [KEY]
    # initialize model
    bm.random.seed(123)
    dim_in = 1
    dim_hid = 10
    dim_out = 1
    batch_size = 1
    model = RNN(dim_in, dim_hid, dim_out, batch_size)
# %%
# generate some data
Nsample = 500
X_train = bm.random.normal(0., 1., size=(batch_size, Nsample, dim_in)) # (Batch,Time,dim)
Y_train = bm.random.normal(10., 1., size=(batch_size, Nsample, dim_out))
# training
def loss_fun(inputs, targets):
    runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
    predicts = runner.predict(inputs)
    loss = bp.losses.mean_squared_error(predicts, targets)
    return loss
grad_fun = bm.grad(loss_fun, grad_vars=model.train_vars().unique(), return_value=True)
opt = bp.optim.Adam(lr=1e-1, train_vars=model.train_vars().unique())
@bm.jit
def train(xs, ys):
    grads, loss = grad_fun(xs, ys)
    opt.update(grads)
    return loss

losses = []
for _ in range(1000):
    losses.append(train(X_train, Y_train))
The concept is very simple.

- Initialize a context manager to set up a device mesh. Here

  with bm.sharding.device_mesh(devices, ['hidden']):
      ...

  means that the hidden dimension will be partitioned over the given devices. Note that the device array should have the same number of dimensions as the list of axis names. For example, if you want to partition the model onto a two-dimensional device mesh along input and hidden, we should set up the context as:

  with bm.sharding.device_mesh(np.asarray(jax.devices()).reshape(2, 2), ['input', 'hidden']):
      ...

- Initialize the weight variables using brainpy.init.variable_(..., axis_names=['input', 'hidden']). The data will be automatically partitioned over the devices if the given axis name matches a device mesh axis.

- Use brainpy.math.jit. This is the key to the parallelization. All functions should have a jit decorator, otherwise the model will not be parallelized according to the setting.
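For intuition, the axis_names=[None, 'hidden'] annotation roughly corresponds to the following pure-JAX placement (just a sketch assuming two visible GPUs; it shows the intended layout, not brainpy's internal implementation):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.asarray(jax.devices()), axis_names=('hidden',))
w = jnp.zeros((10, 10))  # e.g. a (num_hid, num_hid) recurrent weight
# axis_names=[None, 'hidden'] asks for: first axis replicated,
# second axis split across the 'hidden' mesh axis
w_sharded = jax.device_put(w, NamedSharding(mesh, P(None, 'hidden')))
print(w_sharded.sharding)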
Please tell me whether the above code works.
Please also see an example of TPU multi-device partitioning of the COBA-HH network model.
By the way, I apologize for the very late response!
Thanks for the feedback!
One more question about the details: it seems that you partition the model (the hidden states of this RNN) across the two GPUs. Why not partition along the batch axis? It seems more natural for users.
This is a good idea. However, if the batch size is what prevents training the model on one GPU, we can simply decrease the batch size rather than partition it across multiple devices. A more difficult situation is when the model itself is too big to fit on one device. For such cases, we can partition the model across multiple devices, for example when simulating a very large-scale SNN model (which usually has no batch dimension).
Partitioning the hidden states and their interaction matrix is a simple model-parallelization method.
Okay, I see.
By the way, I found that in the model-definition code, changing only the single line for the model state variable is enough for parallelization. There is no need to change the weight TrainVars with the axis_names=['input', 'hidden'] settings.
# define variables
self.state = bp.init.variable(bm.zeros, batch_size, num_hid, axis_names=['hidden'], batch_axis_name=['batch'])  # <<< the key point

# define weights
self.win = bm.TrainVar(bm.random.normal(0., 1., size=(num_in, num_hid)))  # no need to change
self.wrec = bm.TrainVar(bm.random.normal(0., 1., size=(num_hid, num_hid)))
self.wout = bm.TrainVar(bm.random.normal(0., 1., size=(num_hid, num_out)))
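To double check that this single change is enough, the placement of the state can be inspected after building the model inside the device_mesh context; this is just plain JAX inspection (assuming .value exposes the underlying jax.Array, as in the earlier snippets):

# after constructing `model` inside the bm.sharding.device_mesh context
print(model.state.value.shape)
print(model.state.value.sharding)        # shows which mesh axes the state is split over
get_sharding_details(model.state.value)  # the helper function defined earlier in this thread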
Thanks again for the help👍👍👍