sicara/tf-explain

[NOTE] tf version issue: tf.function-decorated function tried to create variables on non-first

vscv opened this issue · 1 comments

vscv commented

This error occurs with TF 2.3; tf-explain works fine with TF 2.2.

```python
import tensorflow as tf
from tf_explain.core.grad_cam import GradCAM

IMAGE_PATH = './0_in.jpg'

# all default from github

# Load pretrained model or your own
model = tf.keras.applications.vgg16.VGG16(weights="imagenet", include_top=True)

# Load a sample image (or multiple ones)
img = tf.keras.preprocessing.image.load_img(IMAGE_PATH, target_size=(224, 224))
img = tf.keras.preprocessing.image.img_to_array(img)
data = ([img], None)

# Start explainer
explainer = GradCAM()
grid = explainer.explain(data, model, class_index=281)  # 281 is the tabby cat index in ImageNet

explainer.save(grid, ".", "grad_cam.png")
```

You need to make a small modification to the tf-explain code.
There are two ways:

  1. Remove the tf.function decorator from "def get_gradients_and_filters" in grad_cam.py.
  2. If you want to keep tf.function, move the grad_model creation from "def get_gradients_and_filters" to "def explain" and pass it in as a function argument.
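
A rough sketch of option 2, assuming TF 2.x with tf.keras. This is not the actual tf-explain source: `build_grad_model` and `build_toy_model` are hypothetical helpers made up for illustration, and the small model only stands in for VGG16. The idea is to build the sub-model once in eager code (what `explain` would do) and hand it to the decorated function, so no Keras model is constructed inside tf.function.

```python
import numpy as np
import tensorflow as tf


def build_toy_model():
    # Hypothetical stand-in for VGG16 so the sketch runs quickly.
    inputs = tf.keras.Input(shape=(8, 8, 3))
    x = tf.keras.layers.Conv2D(4, 3, activation="relu", name="conv")(inputs)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    outputs = tf.keras.layers.Dense(5, activation="softmax")(x)
    return tf.keras.Model(inputs, outputs)


def build_grad_model(model, layer_name):
    # Hypothetical helper: built in eager code, outside any tf.function.
    return tf.keras.Model(
        inputs=model.inputs,
        outputs=[model.get_layer(layer_name).output, model.output],
    )


@tf.function
def get_gradients_and_filters(grad_model, images, class_index):
    # Only runs the already-built grad_model; nothing is constructed here.
    with tf.GradientTape() as tape:
        inputs = tf.cast(images, tf.float32)
        conv_outputs, predictions = grad_model(inputs)
        loss = predictions[:, class_index]
    grads = tape.gradient(loss, conv_outputs)
    # Keep only positive gradients at positive activations.
    guided_grads = (
        tf.cast(conv_outputs > 0, tf.float32)
        * tf.cast(grads > 0, tf.float32)
        * grads
    )
    return conv_outputs, guided_grads


model = build_toy_model()
grad_model = build_grad_model(model, "conv")  # would live in explain()
images = np.random.rand(2, 8, 8, 3).astype("float32")
conv_outputs, guided_grads = get_gradients_and_filters(grad_model, images, 1)
```

Option 1 is the same code with the `@tf.function` line deleted; it runs eagerly, which is usually slower but avoids tracing altogether.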