peng-lab/BaSiCPy

Possibly reduce calls to svd

Closed this issue · 7 comments

https://github.com/peng-lab/PyBaSiC/blob/01d8fe1ae86f2a09ced2cd710bd3e043036deea0/pybasic/tools/inexact_alm_rspca_l1.py#L43

I might be missing something, but it looks like this only needs to run once per call to basic. Its input images do not change during the course of the run. If that is right, we could calculate the singular values once instead of during every reweighting iteration.

I agree, similar code appears here as well, but I think we need some refactoring.
https://github.com/peng-lab/PyBaSiC/blob/01d8fe1ae86f2a09ced2cd710bd3e043036deea0/pybasic/_background.py#L78
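
For illustration, here is a minimal sketch of the kind of refactoring that would pull the computation out of the reweighting loop. The `spectral_norm` helper and the `precomputed_norm` parameter are hypothetical names, not part of the current API:

```python
import numpy as np

def spectral_norm(images: np.ndarray) -> float:
    """Largest singular value of the flattened (m, n) image matrix."""
    p, q, n = images.shape
    flat = np.reshape(images, (p * q, n), order="F")
    # compute_uv=False returns only the singular values (descending order)
    return np.linalg.svd(flat, compute_uv=False)[0]

# Hypothetical caller-side refactor: compute the norm once per call to
# basic and hand it to inexact_alm_rspca_l1, instead of recomputing it
# inside every reweighting iteration.
#
# norm_two = spectral_norm(images)
# for _ in range(max_reweight_iterations):
#     A1_hat, E1_hat, A_offset = inexact_alm_rspca_l1(
#         images, ..., precomputed_norm=norm_two
#     )
```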

Tim is correct here. I'll run some profiling to see how much the SVD bogs down the code.

Based on the output from the profiler, I don't think the SVD is too much of a concern. We should instead look at accelerating the basic mathematical operations, using in-place/accelerated functions where possible (see the sketch after the profiler output below). This also suggests that we should get a pretty reasonable speed boost when moving to GPU.

Here is the output from the profiler:

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    35                                           @profile
    36                                           def inexact_alm_rspca_l1(
    37                                               images: np.ndarray,
    38                                               lambda_darkfield: float,
    39                                               lambda_flatfield: float,
    40                                               get_darkfield: bool,
    41                                               optimization_tol: float,
    42                                               max_iterations: int,
    43                                               weight: np.ndarray = None,
    44                                           ):
    45                                           
    46         2         10.0      5.0      0.0      p = images.shape[0]
    47         2          5.0      2.5      0.0      q = images.shape[1]
    48         2          6.0      3.0      0.0      m = p * q
    49         2          4.0      2.0      0.0      n = images.shape[2]
    50         2      12386.0   6193.0      0.2      images = np.reshape(images, (m, n), order="F")
    51                                           
    52         2          7.0      3.5      0.0      if weight is not None:
    53         2       6079.0   3039.5      0.1          weight = np.reshape(weight, (m, n), order="F")
    54                                               else:
    55                                                   weight = np.ones_like(images)
    56         4      39594.0   9898.5      0.6      svd = np.linalg.svd(
    57         2          4.0      2.0      0.0          images, False, False
    58                                               )  # TODO: Is there a more efficient implementation of SVD?
    59         2         20.0     10.0      0.0      norm_two = svd[0]
    60         2          6.0      3.0      0.0      Y1 = 0
    61         2          5.0      2.5      0.0      ent1 = 1
    62         2          5.0      2.5      0.0      ent2 = 10
    63                                           
    64         2       1445.0    722.5      0.0      A1_hat = np.zeros_like(images)
    65         2         65.0     32.5      0.0      A1_coeff = np.ones((1, images.shape[1]))
    66                                           
    67         2       3376.0   1688.0      0.1      E1_hat = np.zeros_like(images)
    68         2       1578.0    789.0      0.0      W_hat = dct2d(np.zeros((p, q)).T)
    69         2         10.0      5.0      0.0      mu = 12.5 / norm_two
    70         2         10.0      5.0      0.0      mu_bar = mu * 1e7
    71         2          6.0      3.0      0.0      rho = 1.5
    72         2        983.0    491.5      0.0      d_norm = np.linalg.norm(images, ord="fro")
    73                                           
    74         2         31.0     15.5      0.0      A_offset = np.zeros((m, 1))
    75         2       1715.0    857.5      0.0      B1_uplimit = np.min(images)
    76         2          7.0      3.5      0.0      B1_offset = 0
    77                                           
    78         2         28.0     14.0      0.0      A_inmask = np.zeros((p, q))
    79         4         37.0      9.2      0.0      A_inmask[
    80         4        209.0     52.2      0.0          int(np.round(p / 6) - 1) : int(np.round(p * 5 / 6)),
    81         2        104.0     52.0      0.0          int(np.round(q / 6) - 1) : int(np.round(q * 5 / 6)),
    82         2          6.0      3.0      0.0      ] = 1
    83                                           
    84                                               # main iteration loop starts
    85         2          6.0      3.0      0.0      iter = 0
    86         2          5.0      2.5      0.0      converged = False
    87                                           
    88        72        220.0      3.1      0.0      while not converged:
    89        70        261.0      3.7      0.0          iter += 1
    90                                           
    91        70        429.0      6.1      0.0          if len(A1_coeff.shape) == 1:
    92        68       4378.0     64.4      0.1              A1_coeff = np.expand_dims(A1_coeff, 0)
    93        70        299.0      4.3      0.0          if len(A_offset.shape) == 1:
    94                                                       A_offset = np.expand_dims(A_offset, 1)
    95        70      44338.0    633.4      0.7          W_idct_hat = idct2d(W_hat.T)
    96        70     560000.0   8000.0      8.4          A1_hat = np.dot(np.reshape(W_idct_hat, (-1, 1), order="F"), A1_coeff) + A_offset
    97                                           
    98        70     844410.0  12063.0     12.6          temp_W = (images - A1_hat - E1_hat + (1 / mu) * Y1) / ent1
    99        70       2955.0     42.2      0.0          temp_W = np.reshape(temp_W, (p, q, n), order="F")
   100        70      88308.0   1261.5      1.3          temp_W = np.mean(temp_W, axis=2)
   101        70      46870.0    669.6      0.7          W_hat = W_hat + dct2d(temp_W.T)
   102       140      17213.0    123.0      0.3          W_hat = np.maximum(W_hat - lambda_flatfield / (ent1 * mu), 0) + np.minimum(
   103        70       1783.0     25.5      0.0              W_hat + lambda_flatfield / (ent1 * mu), 0
   104                                                   )
   105        70      53415.0    763.1      0.8          W_idct_hat = idct2d(W_hat.T)
   106        70        367.0      5.2      0.0          if len(A1_coeff.shape) == 1:
   107                                                       A1_coeff = np.expand_dims(A1_coeff, 0)
   108        70        262.0      3.7      0.0          if len(A_offset.shape) == 1:
   109                                                       A_offset = np.expand_dims(A_offset, 1)
   110        70     600801.0   8582.9      9.0          A1_hat = np.dot(np.reshape(W_idct_hat, (-1, 1), order="F"), A1_coeff) + A_offset
   111        70     754335.0  10776.2     11.3          E1_hat = images - A1_hat + (1 / mu) * Y1 / ent1
   112        70    2003176.0  28616.8     30.0          E1_hat = _shrinkageOperator(E1_hat, weight / (ent1 * mu))
   113        70     395243.0   5646.3      5.9          R1 = images - E1_hat
   114        70     194674.0   2781.1      2.9          A1_coeff = np.mean(R1, 0) / np.mean(R1)
   115        70       1259.0     18.0      0.0          A1_coeff[A1_coeff < 0] = 0
   116                                           
   117        70        236.0      3.4      0.0          if get_darkfield:
   118                                                       validA1coeff_idx = np.where(A1_coeff < 1)
   119                                           
   120                                                       B1_coeff = (
   121                                                           np.mean(
   122                                                               R1[
   123                                                                   np.reshape(W_idct_hat, -1, order="F")
   124                                                                   > np.mean(W_idct_hat) - 1e-6
   125                                                               ][:, validA1coeff_idx[0]],
   126                                                               0,
   127                                                           )
   128                                                           - np.mean(
   129                                                               R1[
   130                                                                   np.reshape(W_idct_hat, -1, order="F")
   131                                                                   < np.mean(W_idct_hat) + 1e-6
   132                                                               ][:, validA1coeff_idx[0]],
   133                                                               0,
   134                                                           )
   135                                                       ) / np.mean(R1)
   136                                                       k = np.array(validA1coeff_idx).shape[1]
   137                                                       temp1 = np.sum(A1_coeff[validA1coeff_idx[0]] ** 2)
   138                                                       temp2 = np.sum(A1_coeff[validA1coeff_idx[0]])
   139                                                       temp3 = np.sum(B1_coeff)
   140                                                       temp4 = np.sum(A1_coeff[validA1coeff_idx[0]] * B1_coeff)
   141                                                       temp5 = temp2 * temp3 - temp4 * k
   142                                                       if temp5 == 0:
   143                                                           B1_offset = 0
   144                                                       else:
   145                                                           B1_offset = (temp1 * temp3 - temp2 * temp4) / temp5
   146                                                       # limit B1_offset: 0<B1_offset<B1_uplimit
   147                                           
   148                                                       B1_offset = np.maximum(B1_offset, 0)
   149                                                       B1_offset = np.minimum(B1_offset, B1_uplimit / np.mean(W_idct_hat))
   150                                           
   151                                                       B_offset = B1_offset * np.reshape(W_idct_hat, -1, order="F") * (-1)
   152                                           
   153                                                       B_offset = B_offset + np.ones_like(B_offset) * B1_offset * np.mean(
   154                                                           W_idct_hat
   155                                                       )
   156                                                       A1_offset = np.mean(R1[:, validA1coeff_idx[0]], axis=1) - np.mean(
   157                                                           A1_coeff[validA1coeff_idx[0]]
   158                                                       ) * np.reshape(W_idct_hat, -1, order="F")
   159                                                       A1_offset = A1_offset - np.mean(A1_offset)
   160                                                       A_offset = A1_offset - np.mean(A1_offset) - B_offset
   161                                           
   162                                                       # smooth A_offset
   163                                                       W_offset = dct2d(np.reshape(A_offset, (p, q), order="F").T)
   164                                                       W_offset = np.maximum(
   165                                                           W_offset - lambda_darkfield / (ent2 * mu), 0
   166                                                       ) + np.minimum(W_offset + lambda_darkfield / (ent2 * mu), 0)
   167                                                       A_offset = idct2d(W_offset.T)
   168                                                       A_offset = np.reshape(A_offset, -1, order="F")
   169                                           
   170                                                       # encourage sparse A_offset
   171                                                       A_offset = np.maximum(
   172                                                           A_offset - lambda_darkfield / (ent2 * mu), 0
   173                                                       ) + np.minimum(A_offset + lambda_darkfield / (ent2 * mu), 0)
   174                                                       A_offset = A_offset + B_offset
   175                                           
   176        70     506362.0   7233.7      7.6          Z1 = images - A1_hat - E1_hat
   177        70     253366.0   3619.5      3.8          Y1 = Y1 + mu * Z1
   178        70       1786.0     25.5      0.0          mu = np.minimum(mu * rho, mu_bar)
   179                                           
   180                                                   # Stop Criterion
   181        70     241313.0   3447.3      3.6          stopCriterion = np.linalg.norm(Z1, ord="fro") / d_norm
   182        70        444.0      6.3      0.0          if stopCriterion < optimization_tol:
   183         2          7.0      3.5      0.0              converged = True
   184                                           
   185        70        266.0      3.8      0.0          if not converged and iter >= max_iterations:
   186                                                       print("Maximum iterations reached")
   187                                                       converged = True
   188                                           
   189         2         40.0     20.0      0.0      A_offset = np.squeeze(A_offset)
   190         2        239.0    119.5      0.0      A_offset = A_offset + B1_offset * np.reshape(W_idct_hat, -1, order="F")
   191                                           
   192         2          7.0      3.5      0.0      return A1_hat, E1_hat, A_offset
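
To make the in-place suggestion concrete, here is a rough sketch of the hottest arithmetic line rewritten to reuse a preallocated buffer. The dummy shapes are assumptions for illustration only, and this has not been profiled against the current code:

```python
import numpy as np

# Dummy data standing in for the real (m, n) arrays inside the loop.
rng = np.random.default_rng(0)
images = rng.random((512, 16))
A1_hat = np.zeros_like(images)
E1_hat = np.zeros_like(images)
Y1 = np.zeros_like(images)
mu, ent1 = 0.1, 1.0

# Reference expression from the profiled line:
#   temp_W = (images - A1_hat - E1_hat + (1 / mu) * Y1) / ent1
# In-place variant: allocate temp_W once before the while loop, then
# reuse the same buffer every iteration instead of creating temporaries.
temp_W = np.empty_like(images)
np.subtract(images, A1_hat, out=temp_W)
np.subtract(temp_W, E1_hat, out=temp_W)
temp_W += Y1 / mu   # Y1 / mu still allocates; could also be done in-place
temp_W /= ent1
```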

That's surprising, but exciting... looks like we could get a big boost from using jit?

It's been my general experience that jit really only helps with things that Python is intrinsically bad at, like for loops. Operations that are already vectorized over numpy arrays rely on optimized libraries under the hood, so jit generally doesn't help much there.

We actually removed jit from a new implementation of flowfield calculations in cellpose because we found a way to vectorize the operations so you didn't need to trace vectors in a for loop. You could just do massive matrix operations and it ran faster than jit.
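
A toy illustration of that point, using numba purely as an example jit (it is not a dependency of this project): jitting an explicit loop buys a lot over pure Python, but the vectorized NumPy version is already running compiled code, so there is little left for a jit to win.

```python
import numpy as np
import numba

@numba.njit
def row_means_loop(a):
    # Explicit loops: very slow in pure Python, fast once jit-compiled.
    out = np.empty(a.shape[0])
    for i in range(a.shape[0]):
        s = 0.0
        for j in range(a.shape[1]):
            s += a[i, j]
        out[i] = s / a.shape[1]
    return out

def row_means_numpy(a):
    # Already vectorized: the work happens in NumPy's compiled kernels,
    # so wrapping this in a jit gains essentially nothing.
    return a.mean(axis=1)

a = np.random.rand(2000, 2000)
assert np.allclose(row_means_loop(a), row_means_numpy(a))
```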

It seems the svd call isn't strictly necessary: np.linalg.norm can provide the spectral norm directly.
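
Concretely, assuming the only value used from the decomposition is norm_two (the largest singular value), the call could be replaced like this:

```python
import numpy as np

rng = np.random.default_rng(0)
images = rng.random((512, 16))  # stand-in for the reshaped (m, n) image matrix

# Current approach: compute all singular values, keep only the largest.
norm_two_svd = np.linalg.svd(images, compute_uv=False)[0]

# Equivalent: for a 2-D array, ord=2 is the spectral norm, i.e. the
# largest singular value.
norm_two = np.linalg.norm(images, ord=2)

assert np.isclose(norm_two, norm_two_svd)
```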