keras-team/keras-io

Keras Keypoint detection example doesn't converge


Issue Type

Bug

Source

source

Keras Version

Keras 3

Custom Code

No

OS Platform and Distribution

Linux

Python version

3.9

GPU model and memory

No response

Current Behavior?

There is a keypoint detection guide on the Keras website at https://keras.io/examples/vision/keypoint_detection/. The example walks through training a keypoint detection model on the Stanford Dogs dataset. At the end, it shows some of the keypoints predicted by the model, with the comment "Predictions will likely improve with more training.":

(Image: sample keypoint predictions from the tutorial)

I have tried running additional training, but the results stay the same and don't seem to improve. I've also tried training on a custom dataset using the model architecture from the example, and the results are similar: the predicted keypoints don't correspond to the image content.
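One way to make "the predicted keypoints don't correspond to the image" measurable is to compare the model against a trivial baseline that predicts the training-set mean keypoints for every image, and to check whether the predictions vary across images at all. This is a minimal sketch, assuming targets and predictions have been flattened to NumPy arrays of shape `(num_images, NUM_KEYPOINTS)` with coordinates normalized to [0, 1]; the helper names and the random toy data are hypothetical.

```python
import numpy as np


def mean_baseline_mse(train_targets: np.ndarray, val_targets: np.ndarray) -> float:
    """MSE of a hypothetical baseline that predicts the mean training keypoints for every image."""
    mean_keypoints = train_targets.mean(axis=0, keepdims=True)
    return float(np.mean((val_targets - mean_keypoints) ** 2))


def model_mse(preds: np.ndarray, val_targets: np.ndarray) -> float:
    """MSE of the model's predictions on the same validation targets."""
    return float(np.mean((preds - val_targets) ** 2))


def prediction_spread(preds: np.ndarray) -> float:
    """Average per-coordinate standard deviation across images.

    A value near 0 means the model outputs nearly the same keypoints for every image.
    """
    return float(preds.std(axis=0).mean())


# Toy example with random normalized coordinates (illustrative data only).
rng = np.random.default_rng(0)
train_targets = rng.uniform(size=(100, 48))
val_targets = rng.uniform(size=(20, 48))
# Simulates a model that ignores the image and always outputs the mean pose.
collapsed_preds = np.tile(train_targets.mean(axis=0), (20, 1))

print(mean_baseline_mse(train_targets, val_targets))
print(model_mse(collapsed_preds, val_targets))  # same value as the baseline
print(prediction_spread(collapsed_preds))  # close to 0
```

If the model's validation MSE is not clearly below the mean baseline, or its prediction spread is close to zero, one plausible reading is that the model has collapsed to an average pose rather than learning image-dependent keypoints.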

That makes me suspect that there is an issue with the model architecture proposed in the example:

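```python
import keras
from keras import layers

IMG_SIZE = 224  # assumed: the tutorial's input resolution
NUM_KEYPOINTS = 24 * 2  # assumed: 24 keypoints, each with an x and y coordinate


def get_model():
    # Load the pre-trained weights of MobileNetV2 and freeze the weights
    backbone = keras.applications.MobileNetV2(
        weights="imagenet",
        include_top=False,
        input_shape=(IMG_SIZE, IMG_SIZE, 3),
    )
    backbone.trainable = False

    inputs = layers.Input((IMG_SIZE, IMG_SIZE, 3))
    # Scale pixel values to [-1, 1], as MobileNetV2 expects
    x = keras.applications.mobilenet_v2.preprocess_input(inputs)
    x = backbone(x)
    x = layers.Dropout(0.3)(x)
    # The default padding="valid" shrinks the spatial size of the feature map
    x = layers.SeparableConv2D(
        NUM_KEYPOINTS, kernel_size=5, strides=1, activation="relu"
    )(x)
    # Sigmoid keeps each predicted (normalized) coordinate in [0, 1]
    outputs = layers.SeparableConv2D(
        NUM_KEYPOINTS, kernel_size=3, strides=1, activation="sigmoid"
    )(x)

    return keras.Model(inputs, outputs, name="keypoint_detector")
```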
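To see what this head actually produces, it can help to trace the spatial size through the layers. The sketch below assumes `IMG_SIZE = 224` and relies on two facts: MobileNetV2 with `include_top=False` downsamples its input by a factor of 32, and `SeparableConv2D` uses `padding="valid"` by default. Under those assumptions the two convolutions shrink the 7×7 backbone feature map to 1×1, so the model predicts all `NUM_KEYPOINTS` coordinates from a single spatial position. The helper `valid_conv_output_size` is hypothetical.

```python
def valid_conv_output_size(size: int, kernel_size: int, stride: int = 1) -> int:
    """Spatial output size of an unpadded ("valid") convolution."""
    return (size - kernel_size) // stride + 1


IMG_SIZE = 224  # assumed input resolution; exact below only when divisible by 32
backbone_size = IMG_SIZE // 32  # MobileNetV2 feature map without the classification top
after_first_conv = valid_conv_output_size(backbone_size, kernel_size=5)
after_second_conv = valid_conv_output_size(after_first_conv, kernel_size=3)

print(backbone_size, after_first_conv, after_second_conv)  # 7 3 1
```

With a smaller input such as 160, the backbone map would be 5×5, the first convolution would leave 1×1, and the 3×3 convolution would have no valid output position, so this head only works for inputs of roughly 224 pixels and up.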

Is it possible that the suggested model architecture is not suitable for keypoint detection? I'm very interested in a basic keypoint detection architecture, but the example/guide is unfortunately not helpful.

If you can provide some pointers, I can take a pass at improving the example.

Thanks in advance

Standalone code to reproduce the issue or tutorial link

examples/vision/keypoint_detection.py

Relevant log output

No response

Hi,

The tutorial is for demonstration purposes only, and it is not tuned to produce the best results.

Feel free to fine-tune the model as suggested in the "Going further" section of the tutorial:

Going further
Try using other augmentation transforms from imgaug to investigate how that changes the results.
Here, we transferred the features from the pre-trained network linearly, that is, we did not fine-tune it. You are encouraged to fine-tune it on this task and see if that improves the performance. You can also try different architectures and see how they affect the final performance.

If you get better convergence by modifying the model architecture, you can create a PR with those changes. Thanks!

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.