前言
TensorFlow目前在移动端是无法training的,只能跑已经训练好的模型,但一般的保存方式只有单一保存参数或者graph的,如何将参数、graph同时保存呢?
生成模型
主要有两种方法生成模型,一种是通过freeze_graph把tf.train.write_graph()生成的pb文件与tf.train.saver()生成的chkp文件固化之后重新生成一个pb文件,这一种现在不太建议使用。另一种是把变量转成常量之后写入PB文件中。我们简单的介绍下freeze_graph方法。
freeze_graph
这种方法我们需要先使用tf.train.write_graph()以及tf.train.saver()生成pb文件和ckpt文件,代码如下:
with tf.Session() as sess: saver = tf.train.Saver() saver.save(session, "model.ckpt") tf.train.write_graph(session.graph_def, '', 'graph.pb')
然后使用TensorFlow源码中的freeze_graph工具进行固化操作:
首先需要build freeze_graph 工具( 需要 bazel ):
bazel build tensorflow/python/tools:freeze_graph
然后使用这个工具进行固化(/path/to/表示文件路径):
bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=/path/to/graph.pb --input_checkpoint=/path/to/model.ckpt --output_node_names=output/predict --output_graph=/path/to/frozen.pb
convert_variables_to_constants
其实在TensorFlow中传统的保存模型方式是保存常量以及graph的,而我们的权重主要是变量,如果我们把训练好的权重变成常量之后再保存成PB文件,这样确实可以保存权重,就是方法有点繁琐,需要一个一个调用eval方法获取值之后赋值,再构建一个graph,把W和b赋值给新的graph。
牛逼的Google为了方便大家使用,编写了一个方法供我们快速的转换并保存。
首先我们需要引入这个方法
from tensorflow.python.framework.graph_util import convert_variables_to_constants
在想要保存的地方加入如下代码,把变量转换成常量
output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output/predict'])
这里参数第一个是当前的session,第二个为graph,第三个是输出节点名(如我的输出层代码是这样的:)
with tf.name_scope('output'): w_out = tf.Variable(w_alpha * tf.random_normal([1024, MAX_CAPTCHA * CHAR_SET_LEN])) tf.summary.histogram('output/weight', w_out) b_out = tf.Variable(b_alpha * tf.random_normal([MAX_CAPTCHA * CHAR_SET_LEN])) tf.summary.histogram('output/biases', b_out) out = tf.add(tf.matmul(dense2, w_out), b_out) out = tf.nn.softmax(out) predict = tf.argmax(tf.reshape(out, [-1, 11, 36]), 2, name='predict')
由于我们采用了name_scope所以我们在predict之前需要加上output/
生成文件
with tf.gfile.FastGFile('model/CTNModel.pb', mode='wb') as f:
f.write(output_graph_def.SerializeToString())
第一个参数是文件路径,第二个是指文件操作的模式,这里指的是以二进制的方式写入文件。
运行代码,系统会生成一个PB文件,接下来我们要测试下这个模型是否能够正常的读取、运行。
测试模型
在Python环境下,我们首先需要加载这个模型,代码如下:
with open('./model/rounded_graph.pb', 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) output = tf.import_graph_def(graph_def, input_map={'inputs/X:0': newInput_X}, return_elements=['output/predict:0'])
由于我们原本的网络输入值是一个placeholder,这里为了方便输入我们也先定义一个新的placeholder:
newInput_X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT * IMAGE_WIDTH], name="X")
在input_map的参数填入新的placeholder。
在调用我们的网络的时候直接用这个新的placeholder接收数据,如:
text_list = sesss.run(output, feed_dict={newInput_X: [captcha_image]})
然后就是运行我们的网络,看是否可以运行吧。
以上这篇TensorFlow固化模型的实现操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。
免责声明:本站资源来自互联网收集,仅供用于学习和交流,请遵循相关法律法规,本站一切资源不代表本站立场,如有侵权、后门、不妥请联系本站删除!
《魔兽世界》大逃杀!60人新游玩模式《强袭风暴》3月21日上线
暴雪近日发布了《魔兽世界》10.2.6 更新内容,新游玩模式《强袭风暴》即将于3月21 日在亚服上线,届时玩家将前往阿拉希高地展开一场 60 人大逃杀对战。
艾泽拉斯的冒险者已经征服了艾泽拉斯的大地及遥远的彼岸。他们在对抗世界上最致命的敌人时展现出过人的手腕,并且成功阻止终结宇宙等级的威胁。当他们在为即将于《魔兽世界》资料片《地心之战》中来袭的萨拉塔斯势力做战斗准备时,他们还需要在熟悉的阿拉希高地面对一个全新的敌人──那就是彼此。在《巨龙崛起》10.2.6 更新的《强袭风暴》中,玩家将会进入一个全新的海盗主题大逃杀式限时活动,其中包含极高的风险和史诗级的奖励。
《强袭风暴》不是普通的战场,作为一个独立于主游戏之外的活动,玩家可以用大逃杀的风格来体验《魔兽世界》,不分职业、不分装备(除了你在赛局中捡到的),光是技巧和战略的强弱之分就能决定出谁才是能坚持到最后的赢家。本次活动将会开放单人和双人模式,玩家在加入海盗主题的预赛大厅区域前,可以从强袭风暴角色画面新增好友。游玩游戏将可以累计名望轨迹,《巨龙崛起》和《魔兽世界:巫妖王之怒 经典版》的玩家都可以获得奖励。
更新日志
- 中国武警男声合唱团《辉煌之声1天路》[DTS-WAV分轨]
- 紫薇《旧曲新韵》[320K/MP3][175.29MB]
- 紫薇《旧曲新韵》[FLAC/分轨][550.18MB]
- 周深《反深代词》[先听版][320K/MP3][72.71MB]
- 李佳薇.2024-会发光的【黑籁音乐】【FLAC分轨】
- 后弦.2012-很有爱【天浩盛世】【WAV+CUE】
- 林俊吉.2012-将你惜命命【美华】【WAV+CUE】
- 晓雅《分享》DTS-WAV
- 黑鸭子2008-飞歌[首版][WAV+CUE]
- 黄乙玲1989-水泼落地难收回[日本天龙版][WAV+CUE]
- 周深《反深代词》[先听版][FLAC/分轨][310.97MB]
- 姜育恒1984《什么时候·串起又散落》台湾复刻版[WAV+CUE][1G]
- 那英《如今》引进版[WAV+CUE][1G]
- 蔡幸娟.1991-真的让我爱你吗【飞碟】【WAV+CUE】
- 群星.2024-好团圆电视剧原声带【TME】【FLAC分轨】