Feature: FLOPs calculation
innat commented
Reopening this from the following earlier issues:
- #552
- tensorflow/tensorflow#32809
- keras-team/keras#12970
- keras-team/keras#5625
- tensorflow/tensorflow#39834
The workaround currently adopted by the community:
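```python
import tensorflow as tf
from tensorflow.python.profiler import model_analyzer, option_builder

model = tf.keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False
)

# Fix the batch dimension to 1 so every tensor shape is fully defined
# and the profiler can count operations for a single input.
input_signature = [
    tf.TensorSpec(
        shape=(1, *params.shape[1:]),
        dtype=params.dtype,
        name=params.name
    ) for params in model.inputs
]

# Trace the forward pass into a static graph.
forward_graph = tf.function(model, input_signature).get_concrete_function().graph

# Profile the graph, counting floating-point operations only.
options = option_builder.ProfileOptionBuilder.float_operation()
graph_info = model_analyzer.profile(forward_graph, options=options)

# The profiler counts a multiply-add as 2 FLOPs, so halving
# approximates the multiply-accumulate (MAC) count.
flops = graph_info.total_float_ops // 2
print(flops)  # 1925897756
```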
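The `// 2` at the end presumably converts the profiler's total into multiply-accumulates (MACs): assuming TensorFlow's profiler counts each multiply-add in a convolution or matmul as two floating-point operations, halving recovers the MAC count. A minimal pure-Python sketch of that arithmetic for a single convolution; the helper `conv2d_macs` and the layer shape (3x3 kernel, stride 2, `'valid'` padding, 32 filters on a 150x150x3 input) are illustrative assumptions, not taken from the thread:

```python
def conv2d_macs(out_h, out_w, out_ch, k_h, k_w, in_ch, batch=1):
    """Multiply-accumulates of an ungrouped 2-D convolution (bias ignored)."""
    # Each output element needs k_h * k_w * in_ch multiply-accumulates.
    return batch * out_h * out_w * out_ch * k_h * k_w * in_ch


# Hypothetical layer: 3x3 kernel, stride 2, 'valid' padding, 32 filters, 150x150x3 input.
out = (150 - 3) // 2 + 1             # spatial output size: 74
macs = conv2d_macs(out, out, 32, 3, 3, 3)
flops = 2 * macs                     # one multiply + one add per MAC
print(macs, flops)                   # 4731264 9462528
print(flops // 2 == macs)            # True
```

Because the input signature fixes the batch dimension at 1, the reported figure is per image; the count grows linearly with batch size.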
divyashreepathihalli commented
This issue does not seem to be related to keras.
Please file TF-related issues here: https://github.com/tensorflow/tensorflow/issues
innat commented
> This issue does not seem to be related to keras.
How come? How did you determine that it doesn't fit here?
I updated the title: this is not only about TF 2 but about all backends (TensorFlow, PyTorch, JAX).
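Since the request covers every backend, one backend-neutral route is to count analytically from layer shapes instead of profiling a traced TensorFlow graph. A minimal sketch, assuming only ungrouped `Conv2D` and `Dense` layers, square kernels, channels-last inputs, batch size 1, and no cost for biases or activations; `Conv2DSpec`, `DenseSpec`, `conv_output_size` and `count_macs` are hypothetical names, not a Keras API:

```python
from dataclasses import dataclass
from typing import Sequence, Tuple, Union


@dataclass(frozen=True)
class Conv2DSpec:
    """Hypothetical description of an ungrouped, square-kernel 2-D convolution."""
    filters: int
    kernel: int
    stride: int = 1
    padding: str = "valid"  # Keras-style "valid" or "same"


@dataclass(frozen=True)
class DenseSpec:
    """Hypothetical description of a fully connected layer."""
    units: int


LayerSpec = Union[Conv2DSpec, DenseSpec]


def conv_output_size(size: int, kernel: int, stride: int, padding: str) -> int:
    """Spatial output size of a convolution along one axis."""
    if padding == "valid":
        return (size - kernel) // stride + 1
    if padding == "same":
        return -(-size // stride)  # ceil(size / stride)
    raise ValueError(f"unsupported padding: {padding!r}")


def count_macs(input_shape: Tuple[int, int, int], layers: Sequence[LayerSpec]) -> int:
    """Sum multiply-accumulates over `layers` for one channels-last input (batch 1).

    Bias additions, activations and pooling are ignored.
    """
    h, w, c = input_shape
    features = None  # set once a Dense layer has flattened the feature map
    total = 0
    for layer in layers:
        if isinstance(layer, Conv2DSpec):
            if features is not None:
                raise ValueError("Conv2D after Dense is not supported in this sketch")
            h = conv_output_size(h, layer.kernel, layer.stride, layer.padding)
            w = conv_output_size(w, layer.kernel, layer.stride, layer.padding)
            # Each output element needs kernel * kernel * c MACs.
            total += h * w * layer.filters * layer.kernel * layer.kernel * c
            c = layer.filters
        elif isinstance(layer, DenseSpec):
            # Implicitly flatten the feature map before the first Dense layer.
            in_features = features if features is not None else h * w * c
            total += in_features * layer.units
            features = layer.units
        else:
            raise TypeError(f"unsupported layer spec: {layer!r}")
    return total


layers = [Conv2DSpec(8, 3), Conv2DSpec(16, 3, stride=2), DenseSpec(10)]
macs = count_macs((28, 28, 1), layers)
print(macs, 2 * macs)  # 237600 475200 (MACs, then FLOPs at 2 per MAC)
```

Because it works from shapes alone, the count is the same whichever backend executes the model; the trade-off is that every layer type needs its own formula, whereas graph profiling picks up all ops automatically.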