keras-team/tf-keras

Feature: Flops calculation

Closed this issue · 2 comments

Reopening from

Currently, the solution adopted by the community is:

import tensorflow as tf
from tensorflow.python.profiler import model_analyzer, option_builder

model = tf.keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False
) 

input_signature = [
    tf.TensorSpec(
        shape=(1, *params.shape[1:]), 
        dtype=params.dtype, 
        name=params.name
    ) for params in model.inputs
]
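# Trace the model into a concrete function to obtain its forward-pass graph.
forward_graph = tf.function(model, input_signature).get_concrete_function().graph
# Ask the profiler to count floating-point operations.
options = option_builder.ProfileOptionBuilder.float_operation()
graph_info = model_analyzer.profile(forward_graph, options=options)
# Halve the total: the profiler counts a multiply-add as two operations,
# so this value is closer to a multiply-accumulate (MAC) count.
flops = graph_info.total_float_ops // 2
flops  # 1925897756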
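
To illustrate the `// 2` in the snippet above, here is a minimal sketch of the same arithmetic done by hand. It assumes the profiler counts each multiply-add as two floating-point operations. The helpers `conv2d_macs` and `dense_macs` and the layer shape are hypothetical, chosen only for illustration. They are not part of Keras or TensorFlow, and they ignore bias terms, activations, and other layer types.

```python
def conv2d_macs(out_h, out_w, kernel_h, kernel_w, in_channels, out_channels):
    """Multiply-accumulates for one standard 2-D convolution (bias ignored).

    Each output element needs kernel_h * kernel_w * in_channels multiply-adds,
    and there are out_h * out_w * out_channels output elements.
    """
    return out_h * out_w * out_channels * kernel_h * kernel_w * in_channels


def dense_macs(in_features, out_features):
    """Multiply-accumulates for one fully connected layer (bias ignored)."""
    return in_features * out_features


# Illustrative layer shape (an assumption, not read from a real model):
# a 3x3 convolution mapping 3 input channels to 32 output channels
# and producing a 74x74 feature map.
macs = conv2d_macs(74, 74, 3, 3, 3, 32)

# Under the "one multiply-add = two float ops" convention, FLOPs = 2 * MACs,
# so halving a FLOP count recovers the MAC count.
flops = 2 * macs

print("MACs: ", macs)
print("FLOPs:", flops)
```

Because this is plain arithmetic on layer shapes, it does not depend on the backend. A real implementation would still need to walk the model's layers and handle each layer type.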

This issue does not seem to be related to keras.
Please file TF-related issues here: https://github.com/tensorflow/tensorflow/issues

@divyashreepathihalli

> This issue does not seem to be related to keras.

How come? How did you determine that it does not fit here?

I updated the title; it is not only about TF 2 but about all backends: TF, Torch, and JAX.