MIC-DKFZ/napari-sam

3D support

wasserth opened this issue · 24 comments

Hello,
thanks for this tool. It looks really helpful. For medical images it would be great if 3D images were also supported. Is this planned for the near future?

Hey,

Yes, this is indeed planned, and we actually expect to be able to add support for 3D images within the next week! However, SAM is a 2D model, so even if we apply it to 3D data, the user will still have to click on every slice of an image.

Best
Karol

Having to click on every slice would be a big disadvantage. Instead of a 1-click segmentation it would often be a 30-click segmentation. Maybe you can reuse the same point for the neighbouring slices. Users tend to click in the middle of the organ, so the point should also lie inside the organ in several more slices. If you go for 2 neighbouring slices in each direction you could do 5 slices with one click.
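Roughly what I have in mind, as a rough and untested sketch: it assumes a (Z, Y, X) uint8 volume, an already-initialized segment_anything SamPredictor, and the click given as (x, y) pixel coordinates on slice z. The helper function is hypothetical, not part of the plugin.

import numpy as np

def segment_neighbouring_slices(predictor, volume, z, point_xy, radius=2):
    """Hypothetical helper: reuse one click on slice z for slices z-radius..z+radius."""
    masks = np.zeros(volume.shape, dtype=bool)
    for zi in range(max(z - radius, 0), min(z + radius + 1, volume.shape[0])):
        slice_rgb = np.stack([volume[zi]] * 3, axis=-1)  # SAM expects an HxWxC uint8 image
        predictor.set_image(slice_rgb)
        mask, _, _ = predictor.predict(
            point_coords=np.array([point_xy]),  # (1, 2) in (x, y) order
            point_labels=np.array([1]),         # 1 = positive click
            multimask_output=False,
        )
        masks[zi] = mask[0]
    return masks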

I was thinking about how to add 3D support to our plugin and I think I found a way to segment 3D objects with just one click, using some tricks in how I use SAM's prompting mechanism. But before I promise anything, I have to implement and test it first. So stay tuned for more updates!

Thanks for this nice integration. I'm also interested in 3D support. One idea that I have seen implemented in other papers uses the YZ and/or XZ planes to extract orthogonal 2D slices and merges the per-slice results into a 3D object. If the voxel sizes are near-isotropic, and depending on the object to segment, this can work quite nicely.
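As a rough sketch of the merging step only (it assumes you already have three boolean volumes, one per orientation, each obtained by running SAM slice-wise along that axis; the helper name is just for illustration), a simple majority vote could look like this:

import numpy as np

def merge_orthogonal_masks(mask_xy, mask_xz, mask_yz):
    """Hypothetical helper: majority vote over three slice-wise segmentations of the same volume."""
    votes = mask_xy.astype(np.uint8) + mask_xz.astype(np.uint8) + mask_yz.astype(np.uint8)
    return votes >= 2  # a voxel is foreground if at least two orientations agree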

Hey,

I added initial, simple support for 3D images, which I plan to improve over the next week.
At the moment, you can load a 3D image and segment each slice with SAM individually.
I hope to release a new segmentation mode next week that enables the user to segment entire 3D objects with SAM with a single click.

Best,
Karol

I am looking forward to testing it out. I just tried to use it for 3D images, but I get an error. I installed version 0.3.6 and loaded an image stack with Open Folder and a labels image stack with the same dimensions. I get the following traceback:

File /opt/miniconda3/lib/python3.8/site-packages/segment_anything/utils/transforms.py:31, in ResizeLongestSide.apply_image(self=<segment_anything.utils.transforms.ResizeLongestSide object>, image=dask.array<getitem, shape=(640, 640, 3), dtype=u...chunksize=(640, 640, 1), chunktype=numpy.ndarray>)
     27 """
     28 Expects a numpy array with shape HxWxC in uint8 format.
     29 """
     30 target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
---> 31 return np.array(resize(to_pil_image(image), target_size))
        target_size = (1024, 1024)
        np = <module 'numpy' from '/opt/miniconda3/lib/python3.8/site-packages/numpy/__init__.py'>
        image = dask.array<getitem, shape=(640, 640, 3), dtype=uint8, chunksize=(640, 640, 1), chunktype=numpy.ndarray>

File /opt/miniconda3/lib/python3.8/site-packages/torchvision/transforms/functional.py:262, in to_pil_image(pic=dask.array<getitem, shape=(640, 640, 3), dtype=u...chunksize=(640, 640, 1), chunktype=numpy.ndarray>, mode=None)
    259     _log_api_usage_once(to_pil_image)
    261 if not (isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
--> 262     raise TypeError(f"pic should be Tensor or ndarray. Got {type(pic)}.")
    264 elif isinstance(pic, torch.Tensor):
    265     if pic.ndimension() not in {2, 3}:

TypeError: pic should be Tensor or ndarray. Got <class 'dask.array.core.Array'>.

This happens in Click / Instance mode. Can you help here?

Hey,

From the error message it seems that you loaded a dask array. Passing a dask array straight to SAM causes this error. I can add a fix that converts every passed image into a numpy array, but I will only be able to do so this evening. Until then you would need to convert it into a numpy array yourself, for example via the napari console:

import numpy as np

viewer.layers["your_image_layer_name"].data = np.asarray(viewer.layers["your_image_layer_name"].data)

Best,
Karol

I managed to get further with your console example. It was kind of unexpected that when I open a folder, napari opens it as a dask array. In any case, I could compute the embedding for my stack. However, when I click on the stack, I get another traceback:


File /opt/miniconda3/lib/python3.8/site-packages/napari_sam/_widget.py:546, in SamWidget.do_click(self=<napari_sam._widget.SamWidget object>, coords=<class 'numpy.ndarray'> (3,) int64, is_positive=1)
    542     self.point_label = 0
    544 self.points[self.point_label].append(coords)
--> 546 self.run(self.points, self.point_label)
        self.point_label = 2
        self = <napari_sam._widget.SamWidget object at 0x7fd21306cc10>
        self.points = defaultdict(<class 'list'>, {2: [<class 'numpy.ndarray'> (3,) int64, <class 'numpy.ndarray'> (3,) int64]})
    547 with warnings.catch_warnings():
    548     warnings.filterwarnings("ignore", category=FutureWarning)

File /opt/miniconda3/lib/python3.8/site-packages/napari_sam/_widget.py:563, in SamWidget.run(self=<napari_sam._widget.SamWidget object>, points=defaultdict(<class 'list'>, {2: [<class 'numpy.n...(3,) int64, <class 'numpy.ndarray'> (3,) int64]}), point_label=2)
    560         labels = [label] * len(label_points)
    561         labels_flattended.extend(labels)
--> 563     prediction, predicted_slices = self.predict(points_flattened, labels_flattended)
        points_flattened = [<class 'numpy.ndarray'> (3,) int64, <class 'numpy.ndarray'> (3,) int64]
        labels_flattended = [1, 1]
        self = <napari_sam._widget.SamWidget object at 0x7fd21306cc10>
    564 else:
    565     prediction = np.zeros_like(self.label_layer.data)

File /opt/miniconda3/lib/python3.8/site-packages/napari_sam/_widget.py:583, in SamWidget.predict(self=<napari_sam._widget.SamWidget object>, points=<class 'numpy.ndarray'> (2, 3) int64, labels=[1, 1])
    581 def predict(self, points, labels):
    582     points = np.asarray(points)
--> 583     old_point, new_point = self.find_changed_point(np.asarray(self.old_points), points)
        points = <class 'numpy.ndarray'> (2, 3) int64
        np = <module 'numpy' from '/opt/miniconda3/lib/python3.8/site-packages/numpy/__init__.py'>
        self = <napari_sam._widget.SamWidget object at 0x7fd21306cc10>
        self.old_points = <class 'numpy.ndarray'> (2, 2) int64
    584     if self.image_layer.ndim == 2:
    585         self.sam_predictor.features = self.sam_features

File /opt/miniconda3/lib/python3.8/site-packages/napari_sam/_widget.py:725, in SamWidget.find_changed_point(self=<napari_sam._widget.SamWidget object>, old_points=<class 'numpy.ndarray'> (2, 2) int64, new_points=<class 'numpy.ndarray'> (2, 3) int64)
    723     old_point = old_points
    724 else:
--> 725     old_point = np.array([x for x in old_points if not np.any((x == new_points).all(1))])
        old_points = <class 'numpy.ndarray'> (2, 2) int64
        new_points = <class 'numpy.ndarray'> (2, 3) int64
        np = <module 'numpy' from '/opt/miniconda3/lib/python3.8/site-packages/numpy/__init__.py'>
    726 if len(old_points) == 0:
    727     new_point = new_points

File /opt/miniconda3/lib/python3.8/site-packages/napari_sam/_widget.py:725, in <listcomp>(.0=<iterator object>)
    723     old_point = old_points
    724 else:
--> 725     old_point = np.array([x for x in old_points if not np.any((x == new_points).all(1))])
        new_points = <class 'numpy.ndarray'> (2, 3) int64
        np = <module 'numpy' from '/opt/miniconda3/lib/python3.8/site-packages/numpy/__init__.py'>
        x = <class 'numpy.ndarray'> (2,) int64
    726 if len(old_points) == 0:
    727     new_point = new_points

AttributeError: 'bool' object has no attribute 'all'

Hey,

Is it possible for you to upload the image to a file-sharing service and send me the link privately at karol.gotkowski@dkfz.de? This way I would be able to reproduce the error better. I would of course delete the image afterwards and not share it with anybody.

Best,
Karol

Sure, I can send some data to your email. Not sure if there's an issue with my data though. I also checked the napari version, which is the latest. Also unrelated, but somehow napari/PyTorch does not seem to free GPU memory, and I usually get a CUDA out of memory error when I try to reload the data, even when using only a small amount of data (even only one slice).

Thanks for sending me the data. I was able to reproduce the dask-array error on my side and could fix it.

Btw you do not need to provide a segmentation mask yourself. You can just add your image(s) and then add a labels layer by clicking the button shown in the image.

Regarding the CUDA out of memory error, I have not experienced this myself. Can you describe every step that you did until you received the error?

Do you also receive this error with the smallest SAM model, vit_b? What GPU do you have and how much GPU memory (VRAM) does it have?

[screenshot: the napari button for adding a new labels layer]

Thanks for fixing it. The segmentation mask is basically a ground-truth version that I want to augment by painting with the plugin. At the moment, if I have a labels layer with some annotations in it and activate e.g. the Click / Semantic mode on this layer, my existing annotations get removed. Perhaps it makes sense to keep these existing labels in a separate layer and then merge in the automatically plugin-generated labels later. Do you know if there is an easy way to merge two layers with the UI?
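In the meantime I guess something like this in the napari console would merge them (just a sketch; the layer names "ground_truth" and "sam_labels" are placeholders for my two layers, and my manual labels win wherever both layers are annotated):

import numpy as np

gt = viewer.layers["ground_truth"].data   # placeholder name for my manual layer
sam = viewer.layers["sam_labels"].data    # placeholder name for the plugin-generated layer
merged = np.where(gt > 0, gt, sam)        # keep manual labels, fill the rest from the plugin
viewer.add_labels(merged, name="merged_labels")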

In my case, I want to label a lot of objects (of the same semantic type) in the images. What I find is that after I label one object, I always need to manually increase the label index in the left panel. Have you thought about adding a way to move to the next free label quickly, e.g. with a key shortcut or a mouse click (e.g. right click)? This would improve the usability quite a bit and speed things up.

I have an 8 GB GPU, with about 6-7 GB of its memory free. Things are better now when using the smaller vit_b model and not loading 100 sections of the 3D stack at the same time.

To answer my own question about merging label layers, I found this which mentions this plugin.

And with regard to the OP's question about propagating in 3D, I found this plugin, but I haven't tested it myself.

Cool that you found the merging and interpolation plugins! They really seem quite useful :)

In my case, I want to label a lot of objects (of the same semantic type) in the images. What I find is that after I label one object, I always need to manually increase the label index in the left panel. Have you thought about adding a way to move to the next free label quickly, e.g. with a key shortcut or a mouse click (e.g. right click)? This would improve the usability quite a bit and speed things up.

There is actually a napari shortcut for this. If you press "M", the next free label (max_label+1) is selected automatically.
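If you would like a dedicated key of your own, you can also bind one over the napari console. This is just a sketch: the key "n" is arbitrary and "Labels" stands for the name of your labels layer.

@viewer.bind_key("n")
def next_free_label(viewer):
    # Jump to the next free label of the layer named "Labels" (placeholder name)
    labels_layer = viewer.layers["Labels"]
    labels_layer.selected_label = int(labels_layer.data.max()) + 1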

I played further with the plugin and found that for my 3D use case, I'd like to override some mouse controls, e.g.:

  • using left click, and ctrl-left click for picking with the plugin (instead of middle mouse button), and middle mouse down for panning
  • use the mousewheel for changing the z section of the stack

I did not see any preference for mouse bindings in napari, and it's probably also hard-coded in your plugin. Do you know if I would have to change the napari and plugin code for this in a cloned repo, or is there a simpler way?

Yes, the middle-click action to place a point for SAM is hardcoded in my plugin. I figured it is the most convenient way for most users, as it enables the user to:

  • Place a positive click with the middle mouse button
  • Place a negative click by holding down Control and clicking the middle mouse button
  • Zoom the image by holding down right click
  • Pan the image by holding down left click
  • Scroll through the z-sections of the image by holding down Control and using the mousewheel

use the mousewheel for changing the z section of the stack

This is actually supported by napari by default with Control + Mousewheel. If it does not seem to work, click once into the image, as the canvas is probably not focused/active.

using left click, and ctrl-left click for picking with the plugin (instead of middle mouse button), and middle mouse down for panning

You mean that the user should be able to pick the color/label at the current mouse position with Control+Left Click? That actually sounds like a good idea :) I think I will implement that tomorrow 👍

With these things I think there is not a perfect solution for everybody, so the ability to customize is best. Different datasets (2D, 3D, different image content), different labeling workflows, available hardware, and what the person is used to in terms of mouse interactions and key shortcuts all have an influence.

If one plans to spend a lot of time on a certain task, with many repetitive interactions, it makes sense to minimize any friction as much as possible.

For instance, I need to browse a lot in the Z direction, so having to press Ctrl to do that is too much; I just want to use the mousewheel.

With my mouse I can press the scroll wheel as the middle button, but probably not every mouse has one. The issue is that for my labeling workflow the main interaction is the positive click, so I want to do that with the left mouse button, as pressing the middle button requires much more force. A key modifier for the negative click would be fine, like Ctrl+click.

The second most common operation is panning. I am used to doing that with the left mouse button in other tools: holding the button down and moving pans the data.

I guess if the UI does not expose a way to customize this, I'd have to implement it in my own fork. Not optimal, but worth doing. :)

That sounds like a very unique workflow and would probably require some changes on the napari side. napari has some limited options to change hotkeys and interactions under File -> Preferences -> Shortcuts -> Group Viewer/Image/Layer/... However, it seems that it is currently not possible to change the mouse bindings. The best solution for you would probably be to create a feature request at https://github.com/napari/napari/issues.

If you want to override the hotkey/interaction settings from my plugin, it is probably best to fork the repo and modify this part of the code:

def callback_click(self, layer, event):
    if self.annotator_mode == AnnotatorMode.CLICK:
        data_coordinates = self.image_layer.world_to_data(event.position)
        coords = np.round(data_coordinates).astype(int)
        if (not CONTROL in event.modifiers) and event.button == 3:  # Positive middle click
            self.do_click(coords, 1)
            yield
        elif CONTROL in event.modifiers and event.button == 3:  # Negative middle click
            self.do_click(coords, 0)
            yield
        elif (not CONTROL in event.modifiers) and event.button == 1 and self.points_layer is not None and len(self.points_layer.data) > 0:
            # Find the closest point to the mouse click
            distances = np.linalg.norm(self.points_layer.data - coords, axis=1)
            closest_point_idx = np.argmin(distances)
            closest_point_distance = distances[closest_point_idx]
            # Select the closest point if it's within self.point_size pixels of the click
            if closest_point_distance <= self.point_size:
                self.points_layer.selected_data = {closest_point_idx}
            else:
                self.points_layer.selected_data = set()
            yield
        elif (CONTROL in event.modifiers) and event.button == 1:
            picked_label = self.label_layer.data[slicer(self.label_layer.data, coords)]
            self.label_layer.selected_label = picked_label
            yield
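For example, to move the positive/negative SAM clicks from the middle to the left mouse button in your fork, you would mainly change the button checks (in napari/vispy the left button is 1 and the middle button is 3). Roughly and untested:

        if (not CONTROL in event.modifiers) and event.button == 1:  # Positive left click
            self.do_click(coords, 1)
            yield
        elif CONTROL in event.modifiers and event.button == 1:  # Negative Ctrl + left click
            self.do_click(coords, 0)
            yield

Note that this would clash with the existing left-click branches for point selection and label picking, so in your fork those would need a different button or modifier.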

Thanks, ok, I will look into it. Just for reference:

Someone had asked about this in the point annotation interface:
napari/napari#2146

As for custom mouse bindings, someone else made the request in
napari/napari#252 (comment)
and it was implemented in
napari/napari#544

Hi team, thanks for this great work. This is the first SAM project I found that can directly load a nii file. I was wondering, is there any update on 3D support?

Hey,

there is currently slice-based 3D support, meaning that you can segment individual slices of 3D data very easily.
I originally planned to extend the plugin to full 3D support through better prompting, but other projects are more important at the moment.

Hi,
Thanks, team, for your work!
Alternatively, you could try my plugin, which propagates labels between distant annotated slices. You could therefore use napari-sam to segment a few slices (at least the first and last) and then obtain the 3D segmentation by propagation.

@wasserth and @Ike-yang - you could also try micro-sam, which implements 3D predictions (it uses the 2D Segment Anything model and applies it slice by slice, which is convenient for the user).