This is due to the transition to TensorFlow 2: the original DeepExplain package does not support TF2 out of the box. There is an open pull request (marcoancona/DeepExplain#55) that adds TF2 support, as long as you disable eager execution:
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
(the rest of your code)
Here's a code snippet that works out-of-the-box with the above pull request (using the MNE sample dataset):
# import tensorflow and disable eager execution right up front: the TF2 port of
# DeepExplain (marcoancona/DeepExplain#55) only works with eager execution
# disabled, and the DeepExplain context further down needs a graph-mode session.
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
import numpy as np

# mne imports
import mne
from mne import io
from mne.datasets import sample

# EEGNet-specific imports
from EEGModels import EEGNet
from tensorflow.keras import utils as np_utils
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.models import Model
from deepexplain.tensorflow import DeepExplain

# Use the public tf.compat.v1 Keras backend instead of the private
# `tensorflow.python.keras` module. In newer TF 2.x releases that private module
# is a legacy copy of Keras whose session state may not be the one the tf.keras
# model is built in. This backend still provides get_session() for graph mode.
K = tf.compat.v1.keras.backend

# while the default tensorflow ordering is 'channels_last' we set it here
# to be explicit in case the user has changed the default ordering
K.set_image_data_format('channels_last')
##################### Process, filter and epoch the data ######################
# sample.data_path() returns a str in older MNE releases but a pathlib.Path in
# newer ones; normalise to str so the '+' concatenation below works for both.
data_path = str(sample.data_path())

# Set parameters and read data
raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
event_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif'
tmin, tmax = -0., 1
event_id = dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4)

# Setup for reading the raw data
raw = io.Raw(raw_fname, preload=True, verbose=False)
raw.filter(2, None, method='iir')  # replace baselining with high-pass
events = mne.read_events(event_fname)

raw.info['bads'] = ['MEG 2443']  # set bad channels
# keep only the EEG channels, dropping anything marked bad above
picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                       exclude='bads')

# Read epochs
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=False,
                    picks=picks, baseline=None, preload=True, verbose=False)
# MNE stores the event code in the last column of the events array; with the
# event_id mapping above that is 1..4 per trial.
labels = epochs.events[:, -1]

# extract raw data. scale by 1000 due to scaling sensitivity in deep learning
X = epochs.get_data() * 1000  # format is in (trials, channels, samples)
y = labels

# Derive the channel/sample counts from the data rather than hard-coding them
# (for the MNE sample data this is 60 channels x 151 time points). EEGNet takes
# a single "kernel" (depth) dimension.
kernels = 1
n_trials, chans, samples = X.shape
nb_classes = len(event_id)

# take 50/25/25 percent of the data to train/validate/test
train_end = n_trials // 2
validate_end = (3 * n_trials) // 4
X_train = X[:train_end]
Y_train = y[:train_end]
X_validate = X[train_end:validate_end]
Y_validate = y[train_end:validate_end]
X_test = X[validate_end:]
Y_test = y[validate_end:]

# convert labels (1..4) to zero-based one-hot encodings. num_classes is fixed
# so every split has the same width even if one happens to miss a class.
Y_train = np_utils.to_categorical(Y_train - 1, num_classes=nb_classes)
Y_validate = np_utils.to_categorical(Y_validate - 1, num_classes=nb_classes)
Y_test = np_utils.to_categorical(Y_test - 1, num_classes=nb_classes)

# convert data to NHWC (trials, channels, samples, kernels) format.
X_train = X_train.reshape(X_train.shape[0], chans, samples, kernels)
X_validate = X_validate.reshape(X_validate.shape[0], chans, samples, kernels)
X_test = X_test.reshape(X_test.shape[0], chans, samples, kernels)

print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')
# configure the EEGNet-8,2,16 model with kernel length of 32 samples (other
# model configurations may do better, but this is a good starting point)
model = EEGNet(nb_classes=len(event_id), Chans=chans, Samples=samples,
               dropoutRate=0.5, kernLength=32, F1=8, D=2, F2=16,
               dropoutType='Dropout')

# compile the model and set the optimizers
model.compile(loss='categorical_crossentropy', optimizer='adam',
              metrics=['accuracy'])

# count number of parameters in the model
numParams = model.count_params()
print('Number of model parameters:', numParams)

# set a valid path for your system to record model checkpoints; with
# save_best_only only the best epoch (by the monitored validation metric) is kept
checkpointer = ModelCheckpoint(filepath='/tmp/checkpoint.h5', verbose=1,
                               save_best_only=True)

# if the classification task was imbalanced (significantly more trials in one
# class versus the others) you can assign a weight to each class during
# optimization to balance it out. This data is approximately balanced so we
# don't need to do this, but it is shown here for illustration/completeness.
# the syntax is {class_1:weight_1, class_2:weight_2,...}. Here just setting
# the weights all to be 1
class_weights = {cls: 1 for cls in range(len(event_id))}

fittedModel = model.fit(X_train, Y_train, batch_size=16, epochs=5,
                        verbose=2, validation_data=(X_validate, Y_validate),
                        callbacks=[checkpointer], class_weight=class_weights)
# Compute attributions of the test-set predictions with respect to the input.
# This relies on eager execution being disabled so a graph-mode session exists.
with DeepExplain(session=K.get_session()) as de:
    input_tensor = model.layers[0].input
    # Sub-model ending at the second-to-last layer. NOTE(review): presumably
    # the pre-softmax dense layer of EEGNet -- confirm against EEGModels.
    fModel = Model(inputs=input_tensor, outputs=model.layers[-2].output)
    # Calling the sub-model inside the DeepExplain context re-creates its ops
    # there, so DeepExplain's gradient overrides apply to them.
    target_tensor = fModel(input_tensor)

    # Masking with the one-hot Y_test keeps only each trial's true-class
    # score, so the attributions explain the correct class.
    # can use epsilon-LRP as well if you like.
    attributions = de.explain('deeplift', target_tensor * Y_test,
                              input_tensor, X_test)
    # attributions = de.explain('elrp', target_tensor * Y_test, input_tensor, X_test)
Alternatively, you could fix this manually by editing the DeepExplain sources under `deepexplain/tensorflow/` directly, although this is a pretty bad hack:
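- Change … (the original/replacement code for this step is missing from this copy)
- Change … (the original/replacement code for this step is missing from this copy)
- Change … (the original/replacement code for this step is missing from this copy)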
I've verified that this also works (though it hasn't been extensively tested), but the pull request above is the better route.
Very good, that works — thank you!