重庆分公司,新征程启航
为企业提供网站建设、域名注册、服务器等服务
使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用。介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍:
http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/
我对这篇文章进行了整理和汇总。
首先是模型的保存。直接上代码:
#!/usr/bin/env python #-*- coding:utf-8 -*- ############################ #File Name: tut1_save.py #Author: Wang #Mail: wang19920419@hotmail.com #Created Time:2017-08-30 11:04:25 ############################ import tensorflow as tf # prepare to feed input, i.e. feed_dict and placeholders w1 = tf.Variable(tf.random_normal(shape = [2]), name = 'w1') # name is very important in restoration w2 = tf.Variable(tf.random_normal(shape = [2]), name = 'w2') b1 = tf.Variable(2.0, name = 'bias1') feed_dict = {w1:[10,3], w2:[5,5]} # define a test operation that will be restored w3 = tf.add(w1, w2) # without name, w3 will not be stored w4 = tf.multiply(w3, b1, name = "op_to_restore") #saver = tf.train.Saver() saver = tf.train.Saver(max_to_keep = 4, keep_checkpoint_every_n_hours = 1) sess = tf.Session() sess.run(tf.global_variables_initializer()) print sess.run(w4, feed_dict) #saver.save(sess, 'my_test_model', global_step = 100) saver.save(sess, 'my_test_model') #saver.save(sess, 'my_test_model', global_step = 100, write_meta_graph = False)
另外有需要云服务器可以了解下创新互联scvps.cn,海内外云服务器15元起步,三天无理由+7*72小时售后在线,公司持有idc许可证,提供“云服务器、裸金属服务器、高防服务器、香港服务器、美国服务器、虚拟主机、免备案服务器”等云主机租用服务以及企业上云的综合解决方案,具有“安全稳定、简单易用、服务可用性高、性价比高”等特点与优势,专为企业上云打造定制,能够满足用户丰富、多元化的应用场景需求。