google-research/sofima

Non-overlapping tiles

Matthijs-utf8 opened this issue · 1 comment

I am using a dataset that sometimes contains tiles that don't overlap. I've added an image of the coarse offsets below. It shows that one offset (between tiles (0, 0) and (0, 1)) is not computed.

[Image: coarse offsets]

I use the workflow from the em_stitching Colab notebook. I compute the coarse offsets and mesh with this code:

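import matplotlib.pyplot as plt
import numpy as np
from sofima import stitch_rigid

# grid_size and tile_map are defined earlier in the notebook.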
cx, cy = stitch_rigid.compute_coarse_offsets(grid_size, tile_map)

f, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].quiver((0, 1, 2), (0, 1, 2), cx[0, 0, ...], cx[1, 0, ...])
ax[0].set_ylim(-0.5, 2.5)
ax[0].set_xlim(-0.5, 1.5)
ax[0].set_title('horizontal NNs')
ax[1].quiver((0, 1, 2), (0, 1, 2), cy[0, 0, ...], cy[1, 0, ...])
ax[1].set_ylim(-0.5, 1.5)
ax[1].set_xlim(-0.5, 2.5)
ax[1].set_title('vertical NNs')

coarse_mesh = stitch_rigid.optimize_coarse_mesh(cx, cy)

Then I use that to calculate the fine mesh with this code:

import functools as ft

import jax
import jax.numpy as jnp
import numpy as np
from sofima import stitch_elastic, flow_utils, mesh

stride = 20
cx = np.squeeze(cx)
cy = np.squeeze(cy)
fine_x, offsets_x = stitch_elastic.compute_flow_map(tile_map, cx, 0, stride=(stride, stride), batch_size=4)  # (x,y) -> (x+1,y)
fine_y, offsets_y = stitch_elastic.compute_flow_map(tile_map, cy, 1, stride=(stride, stride), batch_size=4)  # (x,y) -> (x,y+1)

# Filter out unreliable flow estimates.
kwargs = {"min_peak_ratio": 1.4, "min_peak_sharpness": 1.4, "max_deviation": 5, "max_magnitude": 0}
fine_x = {k: flow_utils.clean_flow(v[:, np.newaxis, ...], **kwargs)[:, 0, :, :] for k, v in fine_x.items()}
fine_y = {k: flow_utils.clean_flow(v[:, np.newaxis, ...], **kwargs)[:, 0, :, :] for k, v in fine_y.items()}

kwargs = {"min_patch_size": 10, "max_gradient": -1, "max_deviation": -1}
fine_x = {k: flow_utils.reconcile_flows([v[:, np.newaxis, ...]], **kwargs)[:, 0, :, :] for k, v in fine_x.items()}
fine_y = {k: flow_utils.reconcile_flows([v[:, np.newaxis, ...]], **kwargs)[:, 0, :, :] for k, v in fine_y.items()}

data_x = (cx, fine_x, offsets_x)
data_y = (cy, fine_y, offsets_y)

fx, fy, x, nbors, key_to_idx = stitch_elastic.aggregate_arrays(
    data_x, data_y, list(tile_map.keys()),
    coarse_mesh[:, 0, ...], stride=(stride, stride),
    tile_shape=next(iter(tile_map.values())).shape)

@jax.jit
def prev_fn(x):
    target_fn = ft.partial(stitch_elastic.compute_target_mesh, x=x, fx=fx,
                           fy=fy, stride=(stride, stride))
    x = jax.vmap(target_fn)(nbors)
    return jnp.transpose(x, [1, 0, 2, 3])

config = mesh.IntegrationConfig(dt=0.001, gamma=0., k0=0.01, k=0.1, stride=stride,
                                num_iters=1000, max_iters=20000, stop_v_max=0.001,
                                dt_max=100, prefer_orig_order=True,
                                start_cap=0.1, final_cap=10., remove_drift=True)

x, ekin, t = mesh.relax_mesh(x, None, config, prev_fn=prev_fn)

When I run the fine alignment cell, I get this error:


OverflowError                             Traceback (most recent call last)
in <cell line: 16>()
     14
     15 # Compute flow maps for horizontal and vertical directions
---> 16 fine_x, offsets_x = stitch_elastic.compute_flow_map(tile_map, cx, 0, stride=(stride, stride), batch_size=4) # (x,y) -> (x+1,y)
     17 fine_y, offsets_y = stitch_elastic.compute_flow_map(tile_map, cy, 1, stride=(stride, stride), batch_size=4) # (x,y) -> (x,y+1)
     18

/usr/local/lib/python3.10/dist-packages/sofima/stitch_elastic.py in compute_flow_map(tile_map, offset_map, axis, patch_size, stride, batch_size)
    241 rounded_offset = stride[::-1] * np.round(offset / stride[::-1])
    242
--> 243 overlap = -int(offset[axis])
    244 overlap = pre.shape[1 - axis] - (
    245 (pre.shape[1 - axis] - overlap) // stride[1 - axis] * stride[1 - axis]

OverflowError: cannot convert float infinity to integer
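The last frame of the traceback shows why the call fails: the coarse offset is passed through `int()`, and Python cannot convert an infinite float to an integer. The sketch below re-creates only the arithmetic visible at lines 243-245 of the traceback (the rest of `compute_flow_map` is not shown, so this is not sofima's code); the function name `overlap_on_stride_grid` and the 2048 px tile size are hypothetical.

```python
import numpy as np


def overlap_on_stride_grid(offset_along_axis, tile_extent, stride):
    """Re-creation of the overlap arithmetic at lines 243-245 of the traceback.

    Hypothetical helper. offset_along_axis is the coarse offset component
    along the stitching axis (negative when the tiles overlap); tile_extent
    plays the role of pre.shape[1 - axis]; stride is the flow-map stride.
    """
    overlap = -int(offset_along_axis)  # raises OverflowError for +/-inf
    # Widen the overlap so that the rest of the tile (tile_extent - result)
    # spans a whole number of strides.
    return tile_extent - ((tile_extent - overlap) // stride * stride)


print(overlap_on_stride_grid(-200.0, 2048, 20))  # 208

try:
    overlap_on_stride_grid(np.inf, 2048, 20)
except OverflowError as err:
    print(err)  # cannot convert float infinity to integer
```

With a finite offset, the overlap is widened to fit the stride grid; with an inf offset, which is presumably how a failed coarse estimate ends up in `cx`, the `int()` call raises before any flow is computed.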


I suspect the problem is related to non-overlapping tiles, because I hit the same error earlier on a set where some of the tiles don't overlap. With this specific grid, however, I'm sure that is not the case: when I check the tiles by hand in Fiji, they all share overlapping features. Could someone help me figure out how to solve this problem?

Hi,

Apologies for the slow response! If the tiles are truly non-overlapping, or if the overlap is too small to reliably estimate, a possible workaround is something along the following lines:

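  # First argument: the tile grid shape; second: the dict of tile images.
  # req.overlap_x / req.overlap_y come from the responder's own setup and
  # hold the candidate overlap sizes to try along x and y.
  cx, cy = stitch_rigid.compute_coarse_offsets(
      tuple(tile_map.shape),
      tiles,
      overlaps_xy=(list(req.overlap_x), list(req.overlap_y)),
  )
  # Estimate missing (inf) offsets from neighboring tile pairs.
  cx = stitch_rigid.interpolate_missing_offsets(cx, -1)
  cy = stitch_rigid.interpolate_missing_offsets(cy, -2)
  # Mark any offsets that still could not be estimated as nan.
  cx[np.isinf(cx)] = np.nan
  cy[np.isinf(cy)] = np.nan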

  coarse_mesh = stitch_rigid.optimize_coarse_mesh(cx, cy)

This attempts to estimate the coarse position of the tiles by using the relative offset information where it is available, and extrapolating it from neighboring tiles otherwise.
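The idea of estimating a missing offset from neighboring tiles can be illustrated in plain NumPy. This is a hypothetical stand-in, not the implementation of `stitch_rigid.interpolate_missing_offsets`: it replaces each inf entry with the mean of the finite offsets along one axis, and it assumes inf marks a failed estimate while nan marks a pair with no neighbor.

```python
import numpy as np


def fill_missing_along_axis(offsets: np.ndarray, axis: int) -> np.ndarray:
    """Replace +/-inf offsets with the mean of the finite offsets along `axis`.

    Hypothetical illustration only; not sofima's interpolate_missing_offsets.
    nan entries are left untouched and ignored when averaging. An inf entry
    with no finite values along `axis` stays inf.
    """
    out = np.array(offsets, dtype=float)  # work on a copy
    finite = np.isfinite(out)
    # Sum and count of finite values along `axis`, kept broadcastable.
    sums = np.where(finite, out, 0.0).sum(axis=axis, keepdims=True)
    counts = finite.sum(axis=axis, keepdims=True)
    with np.errstate(invalid="ignore", divide="ignore"):
        means = np.broadcast_to(sums / counts, out.shape)
    fill = np.isinf(out) & np.isfinite(means)
    out[fill] = means[fill]
    return out


# Toy horizontal offsets, shape [2 (x, y components), rows, cols].
# Row 1, column 0 failed (inf); column 2 has no right-hand neighbor (nan).
cx = np.array([
    [[-900.0, -904.0, np.nan], [np.inf, -902.0, np.nan]],
    [[2.0, 4.0, np.nan], [np.inf, 6.0, np.nan]],
])

cx = fill_missing_along_axis(cx, -1)
# Mirror the second step of the workaround: anything still inf becomes nan.
cx[np.isinf(cx)] = np.nan
# cx[:, 1, 0] now holds the row's remaining finite offset, (-902.0, 6.0).
```

Averaging along a single axis is just the simplest choice for a sketch; a real estimate may weight neighbors differently.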

Could you please try this and check whether the estimated value in cx makes sense relative to what you can estimate manually for that tile?
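To know which entries of `cx` to compare against a manual estimate, one option is to record where the inf values were before running the workaround. The helper below is hypothetical; it assumes the offset components sit on the first axis, as in the `cx[0, 0, ...]` / `cx[1, 0, ...]` indexing used for the quiver plots.

```python
import numpy as np


def find_inf_offsets(offsets: np.ndarray) -> list:
    """Return the indices (component axis dropped) where any component is inf.

    Hypothetical helper. The first axis is assumed to hold the offset
    components; nan entries are not reported.
    """
    bad = np.isinf(offsets).any(axis=0)
    return [tuple(int(i) for i in idx) for idx in np.argwhere(bad)]


# Toy cx with the layout [2, 1, rows, cols] and one failed estimate.
cx = np.full((2, 1, 2, 2), np.nan)
cx[:, 0, 0, 0] = [-900.0, 3.0]
cx[:, 0, 1, 0] = np.inf

missing = find_inf_offsets(cx)
print(missing)  # [(0, 1, 0)]

# After the interpolation workaround, the filled-in offset for each idx in
# `missing` can be inspected with cx[(slice(None), *idx)].
```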

It's also possible that the default settings in compute_coarse_offsets() do not work well for your data, in which case it could make sense to adjust some of them (overlaps_xy, min_overlaps, min_range, filter_size).