Training the net end to end fails
qingzhouzhen opened this issue · 3 comments
Hi, I tried to train the network end to end, but it runs into an error:
LookupError: No gradient defined for operation 'model_3/feature_resizer_20/ResizeArea' (op type: ResizeArea)
It is caused by the feature-resize operation in the MLSP model, ImageResizer = Lambda(lambda x: K.tf.image.resize_area(x, pool_size), name='feature_resizer'), in applications.py.
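For context, the underlying limitation can be reproduced outside of Keras; a minimal TF1-style sketch (the placeholder shape and pool size here are illustrative):

import tensorflow as tf

x = tf.placeholder(tf.float32, (None, None, None, 3))
y = tf.image.resize_area(x, (5, 5))   # creates a ResizeArea op
tf.gradients(y, x)                     # LookupError: No gradient defined for operation ... (op type: ResizeArea)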
How did you get around this problem? Could you give me some advice? @subpic
Hi Qing Zhou Zhen,
Training the wide MLSP network end to end is not possible, as you've noticed. This only works for the narrow MLSP, which uses GAP instead of resize_area to pool the activations.
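Roughly, the difference is the following (a sketch only; the layer names and the selection of feature blocks are illustrative, not the exact ones used in applications.py). Only the GAP path keeps the whole graph differentiable:

from keras.applications import InceptionResNetV2
from keras.layers import Lambda, GlobalAveragePooling2D, Concatenate
from keras import backend as K

base = InceptionResNetV2(weights=None, include_top=False,
                         input_shape=(None, None, 3))
feature_maps = [base.get_layer(n).output
                for n in ('mixed_5b', 'mixed_6a', 'mixed_7a')]

# narrow MLSP: global average pooling per block -- differentiable,
# so the backbone can be fine-tuned end to end
narrow = Concatenate()([GlobalAveragePooling2D()(f) for f in feature_maps])

# wide MLSP: spatial resize of each block -- ResizeArea has no
# registered gradient, so it blocks end-to-end training
pool_size = (5, 5)
wide = Concatenate()([Lambda(lambda t: K.tf.image.resize_area(t, pool_size))(f)
                      for f in feature_maps])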
Best,
Vlad
I want to train the net end to end. Since I could not find end-to-end training code, I wrote the training code below, based on the existing training code:
import pandas as pd
import os

from kutils import applications as apps
from kutils import tensor_ops as ops
from kutils import model_helper as mh
from keras.models import Model


def model_def():
    # multi-GAP (narrow MLSP) backbone + fully-connected head
    base_model = apps.model_inception_multigap(input_shape)
    pred = apps.fc_layers(base_model.output,
                          name='head',
                          fc_sizes=fc_sizes,
                          dropout_rates=dropout_rates,
                          batch_norm=bn)
    model = Model(inputs=base_model.input, outputs=pred)
    return model


if __name__ == '__main__':
    input_shape = (None, None, 3)
    fc1_size = 2048
    bn = 2
    fc_sizes = [fc1_size, fc1_size / 2, fc1_size / 8, 1]
    dropout_rates = [0.25, 0.25, 0.5, 0]
    image_size = '[orig]'

    root_path = '/mnt/home/research/ava-mlsp/'
    images_path = root_path + 'images/'
    model_name = "irnv2_mlsp_narrow_orig"
    dataset = root_path + 'metadata/AVA_data_official_test.csv'
    ids = pd.read_csv(dataset)

    loss = 'MSE'
    model = model_def()
    monitor_metric = 'val_loss'
    monitor_mode = 'min'
    metrics = ["MAE", ops.plcc_tf]
    outputs = 'MOS'

    gen_params = dict(batch_size=128,
                      data_path=images_path,
                      input_shape=('orig',),
                      inputs='image_name',
                      outputs=outputs,
                      fixed_batches=True)

    helper = mh.ModelHelper(model, model_name, ids,
                            max_queue_size=128,
                            loss=loss,
                            metrics=metrics,
                            monitor_metric=monitor_metric,
                            monitor_mode=monitor_mode,
                            early_stop_patience=5,
                            multiproc=True, workers=3,
                            logs_root=root_path + 'logs',
                            models_root=root_path + 'models',
                            gen_params=gen_params)

    helper.model_name.update(fc1='[%d]' % fc1_size,
                             im=image_size,
                             bn=bn,
                             do=str(dropout_rates).replace(' ', ''),
                             mon='[%s]' % monitor_metric,
                             ds='[%s]' % os.path.split(dataset)[1])
    print helper.model_name()

    # validate all at once
    valid_set = ids[ids.set == 'validation']
    valid_gen = helper.make_generator(valid_set,
                                      batch_size=len(valid_set),
                                      shuffle=False,
                                      deterministic=False)

    # train with decreasing learning rates
    for lr in [1e-3, 1e-4, 1e-5]:
        helper.load_model()
        helper.train(lr=lr, epochs=20, valid_gen=valid_gen)
If batch_size == 1, this code trains successfully, but if I set it to more than 1, it runs into this error:
Epoch 1/20
Traceback (most recent call last):
  File "train_mlsp_narrow_e2e.py", line 81, in <module>
    helper.train(lr=lr, epochs=20, valid_gen=valid_gen)
  File "/gruntdata/xxx/project/ava-mlsp/kutils/model_helper.py", line 275, in train
    use_multiprocessing = params.multiproc)
  File "/home/xxx/anaconda3/envs/aesthetic/lib/python2.7/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "/home/xxx/anaconda3/envs/aesthetic/lib/python2.7/site-packages/keras/engine/training.py", line 2194, in fit_generator
    generator_output = next(output_generator)
  File "/home/xxx/anaconda3/envs/aesthetic/lib/python2.7/site-packages/keras/utils/data_utils.py", line 584, in get
    six.raise_from(StopIteration(e), e)
  File "/home/xxx/.local/lib/python2.7/site-packages/six.py", line 738, in raise_from
    raise value
StopIteration: could not broadcast input array from shape (610,409,3) into shape (398,640,3)
Is this end-to-end training code right? It seems the data reader/processor is incorrect?
StopIteration: could not broadcast input array from shape (610,409,3) into shape (398,640,3)
As the error message suggests, and because batch_size=1 works, it seems you're feeding images of different sizes in the same batch, which raises an error.
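For what it's worth, here is a minimal sketch of the usual workaround, independent of the kutils generator (load_batch and the target size are illustrative, not part of the repo): either keep batch_size=1 for original-resolution images, or bring every image to one fixed shape before stacking it into a batch.

import numpy as np
from PIL import Image

def load_batch(image_paths, target_size=(398, 640)):
    # Resize every image to the same (height, width) so they can be
    # stacked into a single array; with variable-size originals the
    # only alternative is batch_size=1.
    batch = []
    for path in image_paths:
        im = Image.open(path).convert('RGB')
        im = im.resize((target_size[1], target_size[0]), Image.BILINEAR)  # PIL expects (width, height)
        batch.append(np.asarray(im, dtype=np.float32))
    return np.stack(batch)  # shape: (len(image_paths), 398, 640, 3)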