我试图从上一次训练中恢复,但我能够保存模型,但无法恢复它。我有以下代码,它运行时没有错误。我知道它不能恢复它,因为当我重新开始训练时,损失值又回到了很大的值。
有什么帮助吗?
ckpt_path = os.path.abspath(os.path.dirname(__file__)) + '/weights/'
labels_net, loss = vgg16(crop_size)
optimizer = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss)
saver = tf.train.Saver(max_to_keep=3)
# Train
with tf.Session() as sess:
# Load previous weights
if os.listdir(ckpt_path) ==[]:
sess.run(tf.global_variables_initializer())
else:
for file in os.listdir(ckpt_path):
if 'vgg16' in file:
try:
saver = tf.train.import_meta_graph(os.path.join(ckpt_path+file))
saver.restore(sess, ckpt_path+'vgg16-2')
print('Resuming training....')
except:
sess.run(tf.global_variables_initializer())
else:
sess.run(tf.global_variables_initializer())
print('Epoch', 'Training loss')
for epoch_i in range(epochs):
for batch_i in range(batches):
batch_crops = getBatch(crops_train, batch_i, batch_size)
batch_labels = getBatch(labels_train, batch_i, batch_size)
x = sess.graph.get_tensor_by_name('x:0')
y = sess.graph.get_tensor_by_name('y:0')
sess.run(optimizer, feed_dict={x: batch_crops, y: batch_labels})#, options=run_options, run_metadata=run_metadata)
train_loss = sess.run(loss, feed_dict={x: batch_crops, y: batch_labels})
print(epoch_i+1, train_loss)
saver.save(sess, ckpt_path+'vgg16', global_step=2)发布于 2019-11-08 07:21:53
我对张量流了解不多,但是。我认为您加载的文件与您保存的文件不同。
您的加载线是saver.restore(sess, ckpt_path+'vgg16-2')
因此,您要保存到vgg16并从vgg16-2加载
https://stackoverflow.com/questions/58758127
复制相似问题