keras-team/keras-cv

Ragged Tensor in Inference issue

Paryavi opened this issue · 6 comments

I fixed the ragged tensor issue in KerasCV object detection (by the way, I am using the latest Keras 3) with the following code:

def dict_to_tuple(inputs):
    return inputs["images"], bounding_box.to_dense(
        inputs["bounding_boxes"], max_boxes=32
    )
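
For readers unfamiliar with what `to_dense` does in the snippet above, here is a minimal pure-Python sketch of the idea, not KerasCV's implementation: each image's variable-length list of boxes is truncated or padded to exactly `max_boxes` rows so the batch becomes rectangular. The helper name `pad_boxes` and the `-1` sentinel used for padded rows are assumptions made for illustration.

```python
def pad_boxes(boxes_per_image, max_boxes=32, sentinel=-1.0):
    """Pad or truncate each image's box list to exactly `max_boxes` rows.

    Hypothetical helper for illustration only; it mimics the effect of
    converting ragged boxes to a dense layout, using plain Python lists.
    """
    dense = []
    for boxes in boxes_per_image:
        # Keep at most `max_boxes` boxes; any extras are dropped.
        kept = [list(box) for box in boxes[:max_boxes]]
        # Fill the remaining rows with sentinel boxes.
        padding = [[sentinel] * 4 for _ in range(max_boxes - len(kept))]
        dense.append(kept + padding)
    return dense


# Two images with 2 and 1 boxes (xywh), i.e. a "ragged" batch.
ragged = [
    [[10, 20, 30, 40], [5, 5, 15, 15]],
    [[0, 0, 64, 64]],
]
dense = pad_boxes(ragged, max_boxes=3)
# Every image now has exactly 3 rows of 4 values.
```

The trade-off is that downstream code has to recognise and ignore the sentinel rows.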

With this I was able to train the model and get positive mAP values. However, when I run inference using the inference section of the Object Detection documentation (the cell copied below), I get the following ragged tensor error:

model.prediction_decoder = keras_cv.layers.NonMaxSuppression(
    bounding_box_format="xywh",
    from_logits=True,
    iou_threshold=0.5,
    confidence_threshold=0.75,
)

Error:

NotImplementedError Traceback (most recent call last)
in <cell line: 8>()
6 )
7
----> 8 visualize_detections(model, dataset=visualization_ds, bounding_box_format="xywh")

1 frames
/usr/local/lib/python3.10/dist-packages/keras_cv/src/bounding_box/to_ragged.py in to_ragged(bounding_boxes, sentinel, dtype)
55 """
56 if backend.supports_ragged() is False:
---> 57 raise NotImplementedError(
58 "bounding_box.to_ragged was called using a backend which does "
59 "not support ragged tensors. "

NotImplementedError: bounding_box.to_ragged was called using a backend which does not support ragged tensors. Current backend: tensorflow.

visualize_detections(model, dataset=visualization_ds, bounding_box_format="xywh")

@sachinprasadhs @fchollet @ianstenbit
To resolve the error in this thread, should I also limit max_boxes to 32 at inference time (after model training)? If so, how?
I suspect the issue is with the unbatching step. Copying from the documentation:

visualization_ds = eval_ds.unbatch()
visualization_ds = visualization_ds.ragged_batch(16)
visualization_ds = visualization_ds.shuffle(8)

FYI, the images I use are small, so I padded them to 640 by 640 pixels with a resizing layer for training. Should I use the resizing layer at inference time as well?

@Paryavi I think your issue is the use of ragged tensors with Keras 3, which does not support them yet.

NotImplementedError: bounding_box.to_ragged was called using a backend which does not support ragged tensors. Current backend: tensorflow.

Maybe you can pad the bounding boxes instead.

preprocessor = keras.Sequential(
    layers=[
        keras_cv.layers.Resizing(
            input_shape, 
            input_shape,
            bounding_box_format=bbox_format,
            pad_to_aspect_ratio=True
        ),
    ],
)

def pad_fn(inputs):
    inputs["bounding_boxes"] = keras_cv.bounding_box.to_dense(
        inputs["bounding_boxes"], max_boxes=32
    )
    return inputs

visualization_ds = eval_ds.unbatch()
visualization_ds = visualization_ds.ragged_batch(16)
visualization_ds = visualization_ds.map(
    preprocessor, num_parallel_calls=tf.data.AUTOTUNE
)
visualization_ds = visualization_ds.map(
    pad_fn, num_parallel_calls=tf.data.AUTOTUNE
)
visualization_ds = visualization_ds.prefetch(tf.data.AUTOTUNE)

Thanks @innat-asj
I already use padding in the model training section, and I also tried your padding code after training the model (after .fit), but I get this error when running the last part of your code:

visualization_ds = visualization_ds.map(
    preprocessor, num_parallel_calls=tf.data.AUTOTUNE
)
visualization_ds = visualization_ds.map(
    pad_fn, num_parallel_calls=tf.data.AUTOTUNE
)
visualization_ds = visualization_ds.prefetch(tf.data.AUTOTUNE)

Error:
TypeError Traceback (most recent call last)
in <cell line: 1>()
----> 1 visualization_ds = visualization_ds.map(
2 preprocessor, num_parallel_calls=tf.data.AUTOTUNE
3 )
4 visualization_ds= visualization_ds.map(
5 pad_fn, num_parallel_calls=tf.data.AUTOTUNE

18 frames
/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
155 bound_signature = None
156 try:
--> 157 return fn(*args, **kwargs)
158 except Exception as e:
159 if hasattr(e, "_keras_call_info_injected"):

TypeError: Sequential.call() got multiple values for argument 'training'
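
One plausible cause of this `TypeError`, offered as a guess rather than a confirmed diagnosis: `tf.data.Dataset.map` unpacks tuple elements into separate positional arguments. If the dataset already yields `(images, boxes)` tuples (for example after a `dict_to_tuple` step), the second element lands in the slot of the layer's `training` parameter while the framework also passes `training` by keyword. The pure-Python sketch below reproduces that mechanism; `FakeSequential` and `map_element` are hypothetical stand-ins, not Keras or TensorFlow APIs.

```python
class FakeSequential:
    """Hypothetical stand-in for a layer whose call() accepts `training`."""

    def call(self, inputs, training=None):
        return inputs

    def __call__(self, *args, **kwargs):
        # Assumed behaviour: the framework injects `training` as a keyword.
        kwargs.setdefault("training", False)
        return self.call(*args, **kwargs)


def map_element(fn, element):
    """Mimic tuple unpacking: tuple elements become positional arguments."""
    if isinstance(element, tuple):
        return fn(*element)
    return fn(element)


preprocessor = FakeSequential()

# A dict element is passed as a single argument and works.
ok = map_element(preprocessor, {"images": "img", "bounding_boxes": "boxes"})

# A tuple element is unpacked, so "boxes" collides with `training`.
try:
    map_element(preprocessor, ("img", "boxes"))
    error_message = None
except TypeError as exc:
    error_message = str(exc)
```

If this is what is happening, applying the preprocessor while the dataset elements are still dictionaries (before any tuple conversion) should avoid the collision.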

@Paryavi Which backend are you using, and which backend produces the input tensors? I ask because JAX and PyTorch do not support ragged tensors.
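
A rough way to answer the backend question without importing Keras is sketched below. It assumes the usual Keras 3 precedence: the `KERAS_BACKEND` environment variable, then the `"backend"` key of `keras.json` in the Keras home directory, then `tensorflow`. The helper `configured_keras_backend` is hypothetical; with Keras imported, `keras.config.backend()` is the authoritative answer.

```python
import json
import os
from pathlib import Path


def configured_keras_backend(environ=None, config_path=None):
    """Best-effort guess at the backend Keras 3 would select.

    Hypothetical helper, not part of Keras. Assumed precedence:
    KERAS_BACKEND env var > "backend" in keras.json > "tensorflow".
    """
    environ = os.environ if environ is None else environ
    if environ.get("KERAS_BACKEND"):
        return environ["KERAS_BACKEND"]
    try:
        if config_path is None:
            keras_home = environ.get("KERAS_HOME") or Path.home() / ".keras"
            config_path = Path(keras_home) / "keras.json"
        with open(config_path) as f:
            return json.load(f).get("backend", "tensorflow")
    except (OSError, ValueError, RuntimeError, AttributeError):
        # Missing or unreadable config: fall back to the default backend.
        return "tensorflow"


# Backends named in the comment above as lacking ragged tensor support.
# Note that the traceback earlier in this thread shows KerasCV raising the
# same error under Keras 3 even though it reports the tensorflow backend.
NO_RAGGED_SUPPORT = {"jax", "torch"}

backend_name = configured_keras_backend()
backend_lacks_ragged = backend_name in NO_RAGGED_SUPPORT
```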

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.