#56 Doesn't use the fact that a 2D Gaussian is separable, and only works for mu=0.0
Nin17 opened this issue · 1 comment
Nin17 commented
The implementation in the answers for question 56 doesn't make use of the fact that a 2D Gaussian is separable and can be computed as the product of two 1D Gaussians, which is about an order of magnitude faster on my machine. Also, the implementation is incorrect for mu != 0.0.
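The separability follows from splitting the exponent into a sum. Assuming, as in the code, the same mu and sigma on both axes:

$$
\exp\left(-\frac{(x-\mu)^2 + (y-\mu)^2}{2\sigma^2}\right)
= \exp\left(-\frac{(x-\mu)^2}{2\sigma^2}\right)\exp\left(-\frac{(y-\mu)^2}{2\sigma^2}\right)
$$

By contrast, the current answer uses $(D-\mu)^2$ with $D=\sqrt{x^2+y^2}$, which equals $x^2+y^2$ only when $\mu = 0$; this is why it agrees with the other two versions only in that case.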
```python
import numpy as np
import matplotlib.pyplot as plt

def original(m=10, n=10, sigma=1.0, mu=0.0):
    # Current answer: Gaussian in the distance D from the origin, shifted by mu
    X, Y = np.meshgrid(np.linspace(-1, 1, m), np.linspace(-1, 1, n))
    D = np.sqrt(X*X + Y*Y)
    G = np.exp(-((D - mu)**2 / (2.0 * sigma**2)))
    return G

def corrected(m=10, n=10, sigma=1.0, mu=0.0):
    # 2D Gaussian centred at (mu, mu), evaluated on the full grid
    x, y = np.meshgrid(np.linspace(-1, 1, m), np.linspace(-1, 1, n))
    g = np.exp(-(((x - mu)**2 + (y - mu)**2) / (2.0 * sigma**2)))
    return g

def separable(m=10, n=10, sigma=1.0, mu=0.0):
    # Same Gaussian built from two 1D Gaussians; result has shape (n, m)
    x = np.linspace(-1, 1, m)
    y = np.linspace(-1, 1, n)
    gx = np.exp(-((x - mu)**2 / (2.0 * sigma**2)))
    gy = np.exp(-((y - mu)**2 / (2.0 * sigma**2)))
    return gy[:, None] * gx[None, :]
```
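In `separable`, `gy[:, None] * gx[None, :]` broadcasts an (n, 1) column against a (1, m) row, which is the same as the outer product `np.outer(gy, gx)`. A minimal check, assuming the same parameters as above; `separable_outer` is a hypothetical name used only for illustration, and m != n is chosen so the row/column layout is visible:

```python
import numpy as np

def separable_outer(m=10, n=10, sigma=1.0, mu=0.0):
    # 1D Gaussians along each axis (same mu and sigma on both axes)
    x = np.linspace(-1, 1, m)
    y = np.linspace(-1, 1, n)
    gx = np.exp(-((x - mu) ** 2) / (2.0 * sigma**2))
    gy = np.exp(-((y - mu) ** 2) / (2.0 * sigma**2))
    # np.outer(gy, gx)[i, j] == gy[i] * gx[j]: rows follow y, columns follow x,
    # matching the (n, m) layout that np.meshgrid produces by default
    return np.outer(gy, gx)

def corrected(m=10, n=10, sigma=1.0, mu=0.0):
    x, y = np.meshgrid(np.linspace(-1, 1, m), np.linspace(-1, 1, n))
    return np.exp(-((x - mu) ** 2 + (y - mu) ** 2) / (2.0 * sigma**2))

G = separable_outer(m=7, n=5, sigma=2.9, mu=0.5)
print(G.shape)  # (5, 7): n rows, m columns
print(np.allclose(G, corrected(m=7, n=5, sigma=2.9, mu=0.5)))  # True
```

Either spelling works; `np.outer` just makes the intent explicit.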
Test for mu != 0.0:
```python
assert np.allclose(original(mu=0.0, sigma=2.9), separable(mu=0.0, sigma=2.9))
assert np.allclose(original(mu=0.0, sigma=2.9), corrected(mu=0.0, sigma=2.9))
assert np.allclose(separable(mu=0.5, sigma=2.9), corrected(mu=0.5, sigma=2.9))
assert np.allclose(separable(mu=0.5, sigma=2.9), original(mu=0.5, sigma=2.9))
```
Raises:
```
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Untitled-1.ipynb Cell 2 line 4
      2 assert np.allclose(original(mu=0.0, sigma=2.9), corrected(mu=0.0, sigma=2.9))
      3 assert np.allclose(separable(mu=0.5, sigma=2.9), corrected(mu=0.5, sigma=2.9))
----> 4 assert np.allclose(separable(mu=0.5, sigma=2.9), original(mu=0.5, sigma=2.9))

AssertionError:
```
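The last assertion fails because `original` depends only on the distance D from the origin, so for mu != 0 it peaks on a ring of radius mu rather than at (mu, mu). A small sketch that checks where each version reaches 1.0, assuming a 9×9 grid (chosen only so that 0 and 0.5 lie exactly on the grid):

```python
import numpy as np

def original(m=10, n=10, sigma=1.0, mu=0.0):
    X, Y = np.meshgrid(np.linspace(-1, 1, m), np.linspace(-1, 1, n))
    D = np.sqrt(X * X + Y * Y)
    return np.exp(-((D - mu) ** 2) / (2.0 * sigma**2))

def corrected(m=10, n=10, sigma=1.0, mu=0.0):
    x, y = np.meshgrid(np.linspace(-1, 1, m), np.linspace(-1, 1, n))
    return np.exp(-((x - mu) ** 2 + (y - mu) ** 2) / (2.0 * sigma**2))

# np.linspace(-1, 1, 9) -> [-1, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1],
# so index 4 is 0.0 and index 6 is 0.5; arrays are indexed [row=y, col=x]
G_orig = original(9, 9, sigma=2.9, mu=0.5)
G_corr = corrected(9, 9, sigma=2.9, mu=0.5)

print(G_corr[6, 6])        # 1.0: corrected peaks at (x, y) = (0.5, 0.5)
print(G_orig[4, 6])        # 1.0: original peaks at (0.5, 0.0), distance 0.5 from the origin
print(G_orig[6, 6] < 1.0)  # True: original is not maximal at (0.5, 0.5)
```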
Plot the functions:
```python
plt.figure()
plt.title("Original")
plt.imshow(original(mu=0.5, sigma=5.0))
plt.colorbar()

plt.figure()
plt.title("Corrected")
plt.imshow(corrected(mu=0.5, sigma=5.0))
plt.colorbar()

plt.figure()
plt.title("Separable")
plt.imshow(separable(mu=0.5, sigma=5.0))
plt.colorbar()

plt.show()  # only needed when running as a script rather than in a notebook
```
Time the functions:
```python
%timeit original(1000, 1001, 10)
%timeit corrected(1000, 1001, 10)
%timeit separable(1000, 1001, 10)
```
Output:
```
9.47 ms ± 295 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
7.59 ms ± 198 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
894 µs ± 32.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
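`%timeit` is an IPython magic. Outside a notebook, a rough equivalent is the standard-library `timeit` module. The sketch below is only an illustration (absolute times will differ by machine) and also prints how many exponentials each approach evaluates: the grid version calls `np.exp` on all m·n points, while the separable one calls it on m + n values and then does m·n multiplications.

```python
import timeit
import numpy as np

def corrected(m=10, n=10, sigma=1.0, mu=0.0):
    x, y = np.meshgrid(np.linspace(-1, 1, m), np.linspace(-1, 1, n))
    # exp is evaluated at every one of the m * n grid points
    return np.exp(-((x - mu) ** 2 + (y - mu) ** 2) / (2.0 * sigma**2))

def separable(m=10, n=10, sigma=1.0, mu=0.0):
    x = np.linspace(-1, 1, m)
    y = np.linspace(-1, 1, n)
    # exp is evaluated only m + n times; building the grid costs m * n multiplications
    gx = np.exp(-((x - mu) ** 2) / (2.0 * sigma**2))
    gy = np.exp(-((y - mu) ** 2) / (2.0 * sigma**2))
    return gy[:, None] * gx[None, :]

m, n = 1000, 1001
print("exp evaluations:", m * n, "vs", m + n)  # exp evaluations: 1001000 vs 2001

for f in (corrected, separable):
    # best of 3 repeats of 10 calls each, reported as time per call
    t = min(timeit.repeat(lambda: f(m, n, 10), number=10, repeat=3)) / 10
    print(f"{f.__name__}: {t * 1e3:.2f} ms per call")
```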
rougier commented
Thanks. You're right about the corrected version and the separable version, though I would prefer to show the non-separable version first and the separable one as an additional remark. Could you make a PR (modifying the ktx files)?