Why does the loss become NaN after several epochs of training on the provided QB dataset?
I ran into a bug while reproducing your experiment. Using the QB dataset you provided, I generated the TFRecord files following your process. During training, the first epoch behaves normally, but in epoch 2 the training loss becomes NaN, so the resulting model does not work on the test set. The log is below.
As you can see, the losses of both D and G are NaN. This is really confusing, because I didn't change your code apart from the parameters.
WARNING:tensorflow:From C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\compat\v2_compat.py:107: disable_resource_variables (from tensorflow.python.ops.variable_scope) is deprecated and will be removed in a future version.
Instructions for updating:
non-resource variables are not supported in the long term
../../data/TFRecords/QB_train_64.tfrecords ../../data/TFRecords/QB_test_64.tfrecords ../../data/Output/QB_test_64_psgan
../../data/Output/QB_test_64_psgan
train_tfrecord = ../../data/TFRecords/QB_train_64.tfrecords
test_tfrecord = ../../data/TFRecords/QB_test_64.tfrecords
mode = train
output_dir = ../../data/Output/QB_test_64_psgan
checkpoint = None
max_steps = None
max_epochs = 5
summary_freq = 0
progress_freq = 200
trace_freq = 0
display_freq = 0
save_freq = 1000
batch_size = 4
lr = 0.0001
beta1 = 0.5
l1_weight = 100.0
gan_weight = 1.0
ndf = 32
train_count = 4821
test_count = 81
gpus = 0
blk = 64
Queue-based input pipelines have been replaced by tf.data. Use tf.data.Dataset.from_tensor_slices(string_tensor).shuffle(tf.shape(input_tensor, out_type=tf.int64)[0]).repeat(num_epochs). If shuffle=False, omit the .shuffle(...).
WARNING:tensorflow:From C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\training\input.py:262: input_producer (from tensorflow.python.training.input) is deprecated and will be removed in a future version.
Instructions for updating:
Queue-based input pipelines have been replaced by tf.data. Use tf.data.Dataset.from_tensor_slices(input_tensor).shuffle(tf.shape(input_tensor, out_type=tf.int64)[0]).repeat(num_epochs). If shuffle=False, omit the .shuffle(...).
WARNING:tensorflow:From C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\training\input.py:184: limit_epochs (from tensorflow.python.training.input) is deprecated and will be removed in a future version.
Instructions for updating:
Queue-based input pipelines have been replaced by tf.data. Use tf.data.Dataset.from_tensors(tensor).repeat(num_epochs).
WARNING:tensorflow:From C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\training\input.py:192: QueueRunner.init (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the tf.data module.
WARNING:tensorflow:From C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\training\input.py:191: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
parameter_count = 2277536
2022-10-21 21:56:15.110273: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8401
2022-10-21 21:56:16.443692: I tensorflow/stream_executor/cuda/cuda_blas.cc:1614] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
progress epoch 1 step 199 image/sec 18.3 remaining 21m
discrim_loss 0.73236054
gen_loss_GAN 1.3829831
gen_loss_L1 68.04657
progress epoch 1 step 399 image/sec 19.7 remaining 19m
discrim_loss 1.0204148
gen_loss_GAN 1.1590607
gen_loss_L1 37.19431
progress epoch 1 step 599 image/sec 20.4 remaining 17m
discrim_loss 0.8676947
gen_loss_GAN 1.3314455
gen_loss_L1 23.225008
progress epoch 1 step 799 image/sec 20.6 remaining 16m
discrim_loss 0.9771492
gen_loss_GAN 1.5663068
gen_loss_L1 26.68383
progress epoch 1 step 999 image/sec 20.8 remaining 16m
discrim_loss 0.83579296
gen_loss_GAN 1.7441618
gen_loss_L1 25.68571
saving model
progress epoch 1 step 1199 image/sec 20.9 remaining 15m
discrim_loss 0.59473586
gen_loss_GAN 2.3351698
gen_loss_L1 27.571407
progress epoch 2 step 193 image/sec 21.1 remaining 14m
discrim_loss nan
gen_loss_GAN nan
gen_loss_L1 nan
progress epoch 2 step 393 image/sec 21.3 remaining 13m
discrim_loss nan
gen_loss_GAN nan
gen_loss_L1 nan
progress epoch 2 step 593 image/sec 21.5 remaining 13m
discrim_loss nan
gen_loss_GAN nan
gen_loss_L1 nan
progress epoch 2 step 793 image/sec 21.6 remaining 12m
discrim_loss nan
gen_loss_GAN nan
gen_loss_L1 nan
saving model
WARNING:tensorflow:From C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\training\saver.py:1064: remove_checkpoint (from tensorflow.python.checkpoint.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to delete files with this prefix.
progress epoch 2 step 993 image/sec 21.7 remaining 11m
discrim_loss nan
gen_loss_GAN nan
gen_loss_L1 nan
progress epoch 2 step 1193 image/sec 21.8 remaining 11m
discrim_loss nan
gen_loss_GAN nan
gen_loss_L1 nan
progress epoch 3 step 187 image/sec 21.9 remaining 10m
discrim_loss nan
gen_loss_GAN nan
gen_loss_L1 nan
progress epoch 3 step 387 image/sec 22.0 remaining 9m
discrim_loss nan
gen_loss_GAN nan
gen_loss_L1 nan
progress epoch 3 step 587 image/sec 22.1 remaining 9m
discrim_loss nan
gen_loss_GAN nan
gen_loss_L1 nan
saving model
progress epoch 3 step 787 image/sec 22.1 remaining 8m
discrim_loss nan
gen_loss_GAN nan
gen_loss_L1 nan
progress epoch 3 step 987 image/sec 22.1 remaining 7m
discrim_loss nan
gen_loss_GAN nan
gen_loss_L1 nan
progress epoch 3 step 1187 image/sec 22.2 remaining 7m
discrim_loss nan
gen_loss_GAN nan
gen_loss_L1 nan
progress epoch 4 step 181 image/sec 22.2 remaining 6m
discrim_loss nan
gen_loss_GAN nan
gen_loss_L1 nan
progress epoch 4 step 381 image/sec 22.3 remaining 6m
discrim_loss nan
gen_loss_GAN nan
gen_loss_L1 nan
saving model
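The progress lines above can be converted to global step numbers to narrow down when the NaN first appears. Below is a minimal plain-Python sketch. It assumes the per-epoch step counter resets every epoch and that global_step = (epoch - 1) * steps_per_epoch + step, with steps_per_epoch = ceil(train_count / batch_size). The log's own numbers agree with this: epoch 1 step 1199 is followed 200 steps later by epoch 2 step 193. The helper names are hypothetical and not part of the PSGAN code.

```python
import math
import re

# Excerpt of the training log above (each progress line is followed by its loss lines).
LOG = """\
progress epoch 1 step 999 image/sec 20.8 remaining 16m
discrim_loss 0.83579296
gen_loss_GAN 1.7441618
gen_loss_L1 25.68571
progress epoch 1 step 1199 image/sec 20.9 remaining 15m
discrim_loss 0.59473586
gen_loss_GAN 2.3351698
gen_loss_L1 27.571407
progress epoch 2 step 193 image/sec 21.1 remaining 14m
discrim_loss nan
gen_loss_GAN nan
gen_loss_L1 nan
"""

TRAIN_COUNT = 4821  # train_count from the parameter dump
BATCH_SIZE = 4      # batch_size from the parameter dump
# Assumption: a final partial batch still counts as one step, hence ceil.
STEPS_PER_EPOCH = math.ceil(TRAIN_COUNT / BATCH_SIZE)

PROGRESS_RE = re.compile(r"progress epoch (\d+) step (\d+)")
LOSS_RE = re.compile(r"(discrim_loss|gen_loss_GAN|gen_loss_L1) (\S+)")


def parse_log(text):
    """Yield (global_step, {loss_name: value}) for each progress record in the log."""
    record, step = None, None
    for line in text.splitlines():
        m = PROGRESS_RE.match(line)
        if m:
            if record is not None:
                yield step, record  # emit the previous record before starting a new one
            epoch, local_step = int(m.group(1)), int(m.group(2))
            step = (epoch - 1) * STEPS_PER_EPOCH + local_step
            record = {}
            continue
        m = LOSS_RE.match(line)
        if m and record is not None:
            record[m.group(1)] = float(m.group(2))  # float("nan") parses to NaN
    if record is not None:
        yield step, record


def last_finite_and_first_nan(text):
    """Return (last step with all-finite losses, first step with a non-finite loss or None)."""
    last_ok = None
    for step, losses in parse_log(text):
        if all(math.isfinite(v) for v in losses.values()):
            last_ok = step
        else:
            return last_ok, step
    return last_ok, None


print(STEPS_PER_EPOCH)                  # 1206
print(last_finite_and_first_nan(LOG))   # (1199, 1399)
```

Under these assumptions, the losses are still finite at global step 1199 and are NaN by step 1399. The first NaN therefore appears within roughly 200 steps of the 4821 training patches being consumed once (1206 steps). Rerunning with a smaller progress_freq, or checking every batch for non-finite values, would narrow it further. The log also shows "saving model" continuing after the NaN, so the checkpoints written from that point on are most likely unusable.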
Hello, I have the same problem as you. Have you solved it?
Sorry, I haven't solved this problem. Maybe this GAN-based model is just hard to train.
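One place where GAN losses commonly run into numerical trouble is taking the log of a sigmoid output. Many pix2pix-derived TensorFlow codebases compute a term like -log(sigmoid(x) + eps). I have not verified that this repository does the same, and it is not confirmed to be the cause of the NaN here. The eps keeps the value finite, but once the discriminator saturates the loss gets clamped near -log(eps) and stops tracking the logit. The standard alternative is to work directly on logits with the log1p/exp form used by tf.nn.sigmoid_cross_entropy_with_logits. The sketch below is plain Python on a single logit, and EPS = 1e-12 is an assumed value.

```python
import math

EPS = 1e-12  # assumed small constant added inside the log, pix2pix-style


def naive_neg_log_sigmoid(x):
    """-log(sigmoid(x) + EPS); saturates at -log(EPS) ~= 27.63 once sigmoid(x) << EPS."""
    # math.exp(-x) would raise OverflowError for very negative x, so treat p as 0 there
    p = 0.0 if x < -700 else 1.0 / (1.0 + math.exp(-x))
    return -math.log(p + EPS)


def stable_neg_log_sigmoid(x):
    """-log(sigmoid(x)) == softplus(-x), computed as max(-x, 0) + log1p(exp(-|x|))."""
    return max(-x, 0.0) + math.log1p(math.exp(-abs(x)))


def stable_bce_with_logits(logit, label):
    """Binary cross-entropy on a raw logit: max(x, 0) - x * z + log(1 + exp(-|x|))."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


for x in (2.0, -40.0):
    print(x, naive_neg_log_sigmoid(x), stable_neg_log_sigmoid(x))
```

For a moderate logit (2.0) the two forms agree. For a badly wrong, confident logit (-40.0), the naive form reports about 27.63 while the true loss is 40, and its gradient shrinks toward zero. In TensorFlow the analogous change would be to have the discriminator return logits and use tf.nn.sigmoid_cross_entropy_with_logits. Lowering lr or clipping gradients are other common things to try if the NaN persists.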
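Hello, I've recently been reading the PSGAN code too. When building the dataset from the raw data the author provided, should the images be cut directly as the code does, or do they first need operations such as filtering and correction before being cut and passed to to_patch? I'd be grateful if you could explain. I'm just starting to learn about this area, so please excuse me if this question is out of place!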
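I can't say what filtering or correction, if any, the authors applied before to_patch. What can be sketched is the cutting step itself, plus one check worth running on any generated dataset. If preprocessing normalizes with (x - min) / (max - min), a patch whose pixels are all equal divides by zero. In NumPy or TensorFlow that produces NaN, which would then propagate into training. The plain-Python sketch below makes several assumptions: non-overlapping blk x blk tiles (blk = 64 as in the parameter dump), ragged edges dropped, and per-patch min-max scaling to [-1, 1]. The function names are hypothetical and are not the repository's to_patch.

```python
import math

BLK = 64  # patch size, matching blk = 64 in the parameter dump


def cut_patches(image, blk=BLK):
    """Cut a 2-D image (list of rows) into non-overlapping blk x blk patches, dropping ragged edges."""
    rows, cols = len(image), len(image[0])
    patches = []
    for r in range(0, rows - blk + 1, blk):
        for c in range(0, cols - blk + 1, blk):
            patches.append([row[c:c + blk] for row in image[r:r + blk]])
    return patches


def minmax_to_unit_range(patch, eps=1e-8):
    """Scale a patch to [-1, 1]; eps stops a constant patch from dividing by zero."""
    values = [v for row in patch for v in row]
    lo, hi = min(values), max(values)
    scale = max(hi - lo, eps)
    return [[2.0 * (v - lo) / scale - 1.0 for v in row] for row in patch]


def has_nonfinite(patch):
    """True if any pixel is NaN or infinite; such patches should be dropped before writing TFRecords."""
    return any(not math.isfinite(v) for row in patch for v in row)


# A 130 x 130 synthetic image yields 2 x 2 = 4 full 64 x 64 patches.
image = [[float(r * 130 + c) for c in range(130)] for r in range(130)]
patches = cut_patches(image)
print(len(patches), len(patches[0]), len(patches[0][0]))  # 4 64 64
```

A constant patch comes out as all -1.0 rather than NaN, and has_nonfinite can serve as a final filter over every patch before it is serialized.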