The cosine bump basis functions used in FilterNet seem to differ from the formula in Pillow et al., J Neurosci 2005
CloudyDory opened this issue
The cosine bump basis function of LGN cells in Pillow et al., J Neurosci 2005 is:

B(t) = (cos(log(t + tau) - phi) + 1) / 2,  if phi - pi < log(t + tau) < phi + pi (and 0 otherwise)
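For reference, a direct NumPy implementation of that equation, as I read it, is only a few lines (the function name is my own, and t + tau is assumed positive so the log is defined):

import numpy as np

def pillow_bump(t, tau, phi):
    # My reading of the Pillow et al. 2005 raised-cosine bump (hypothetical helper,
    # not part of bmtk): a half-cosine on a log-time axis, centered at phi,
    # zero outside the interval (phi - pi, phi + pi).
    x = np.log(t + tau) - phi  # assumes t + tau > 0
    return np.where(np.abs(x) < np.pi, (np.cos(x) + 1.0) / 2.0, 0.0)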
However, the actual function makeBasis_StimKernel implemented in bmtk/simulator/filternet/lgnmodel/fitfuns.py (lines 28-81 and 97-122) seems to be more complex than that and lacks documentation (the tutorial in bmtk/docs/tutorial/07_filter_models.ipynb is not enough). For example, the variables b (line 33), ylim (line 37), and db (line 40) are all hard-coded, do not appear in the original equation, and I cannot find an explanation of what they mean. I am also not sure which variable in the original equation kpeaks represents, or why there is an extra normalization step.
There is also a demo in bmtk/docs/tutorial/helpers/filternet_images_helpers.py that draws a figure of the basis functions with different parameters. I am able to reproduce that figure with much simpler code that summarizes the actual computation bmtk performs. The computation in get_temporal_kernel() in the following code is indeed different from the original equation in Pillow et al., J Neurosci 2005. So, what is the reason that bmtk uses a different formula?
import numpy as np
import matplotlib.pyplot as plt

#%% Helper functions
def get_temporal_kernel(t, kpeaks, delay, weights):
    '''
    Inputs:
        t: [length,1] array
        kpeaks: [1,2] array
        delay: [1,2] array
        weights: [2,] array
    Output:
        kernel: [length,] array
    '''
    log_t = np.log(t + 1.3 - delay) - np.log(kpeaks)  # Use 1.3, not 0.3, here to compensate for the extra 1-point offset in line 65 of `fitfuns.py`.
    log_t_pi = np.clip(np.pi * log_t / get_temporal_kernel.db2, -np.pi, np.pi)
    log_t_pi[np.isnan(log_t_pi)] = -np.pi
    basis = (np.cos(log_t_pi) + 1) / 2.0
    basis_norm = basis / np.linalg.norm(basis, ord=2, axis=0)
    kernel = basis_norm @ weights
    return kernel

get_temporal_kernel.db2 = 2.0 * np.diff(np.log([100.3, 200.3]))  # 2*db in `fitfuns.py`
#%% Plot kernels with different parameters
weights = np.array([[30.0, -20.0], [30.0, -1.0], [15.0, -20.0]])
kpeaks = np.array([[3.0, 5.0], [3.0, 30.0], [20.0, 40.0]])
delays = np.array([[0.0, 0.0], [0.0, 60.0], [20.0, 60.0]])
t = np.expand_dims(np.arange(0, 150, 1), axis=1)  # milliseconds

fig, axes = plt.subplots(3, 3, figsize=(10, 7))
ri = 0
for ci in range(weights.shape[0]):
    kernel = get_temporal_kernel(t, np.array([[9.67, 20.03]]), np.array([[0.0, 1.0]]), weights[ci, :])
    idx = np.abs(kernel) > 0.0
    axes[ri, ci].plot(-t[idx] / 1000, kernel[idx])
    axes[ri, ci].set_ylim([-3.5, 10.0])
    axes[ri, ci].text(0.05, 0.90, 'weights={}'.format(weights[ci, :]), horizontalalignment='left',
                      verticalalignment='top', transform=axes[ri, ci].transAxes)
axes[0, 0].set_ylabel('effect of weights')
ri += 1

# The kpeaks parameters control the spread of both peaks; the second peak must have a bigger spread.
for ci in range(kpeaks.shape[0]):
    kernel = get_temporal_kernel(t, kpeaks[[ci], :], np.array([[0.0, 1.0]]), np.array([30.0, -20.0]))
    idx = np.abs(kernel) > 0.0
    axes[ri, ci].plot(-t[idx] / 1000, kernel[idx])
    axes[ri, ci].set_xlim([-0.15, 0.005])
    axes[ri, ci].text(0.05, 0.90, 'kpeaks={}'.format(kpeaks[ci, :]), horizontalalignment='left',
                      verticalalignment='top', transform=axes[ri, ci].transAxes)
axes[1, 0].set_ylabel('effects of kpeaks')
ri += 1

for ci in range(delays.shape[0]):
    kernel = get_temporal_kernel(t, np.array([[9.67, 20.03]]), delays[[ci], :], np.array([30.0, -20.0]))
    idx = np.abs(kernel) > 0.0
    axes[ri, ci].plot(-t[idx] / 1000, kernel[idx])
    axes[ri, ci].set_xlim([-0.125, 0.001])
    axes[ri, ci].text(0.05, 0.90, 'delays={}'.format(delays[ci, :]), horizontalalignment='left',
                      verticalalignment='top', transform=axes[ri, ci].transAxes)
axes[2, 0].set_ylabel('effects of delays')
fig.show()
Besides, the function ff() in lines 108-117 of bmtk/simulator/filternet/lgnmodel/fitfuns.py loops over a NumPy array element by element, even though this could easily be vectorized. Why does it use the slower computation?
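For example, assuming ff(x, c, dc) evaluates the same clipped raised cosine that my get_temporal_kernel() above reproduces with np.clip (that is my reading, not a verified description of the bmtk source), a vectorized version could be as short as:

import numpy as np

def ff_vectorized(x, c, dc):
    # Hypothetical vectorized sketch of ff(): assumes it returns (cos(u) + 1) / 2
    # with u = (x - c) * pi / (2 * dc) clipped to [-pi, pi], computed for the
    # whole array at once instead of in a Python loop over elements.
    u = np.clip((x - c) * np.pi / (2.0 * dc), -np.pi, np.pi)
    return (np.cos(u) + 1.0) / 2.0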