当前位置:网站首页>加载 已训练模型 张量的 几种方法
加载 已训练模型 张量的 几种方法
2022-08-08 22:47:00 【Sarah ฅʕ•̫͡•ʔฅ】
1、saver.restore()
用 saver.restore()加载 模型 之前,首先要 定义 模型 的 计算图,具体操作如下:
#重新定义 计算图
def define_graph(input):
graph_define
return computed_tensor #返回需要计算的张量
with tf.Session() as sess:
#input = tf.placeholder(tf.float,shape,name="input")
computed_tensor = define_graph(input)
#加载模型
saver = tf.train.Saver()
saver.restore(sess, model.ckpt)
#计算 张量
sess.run(computed_tensor, feed_dict={
input:input_data})
2、saver.restore() + tf.train.import_meta_graph()
#利用 tf.train.import_meta_graph()载入计算图
saver = tf.train.import_meta_graph('model.ckpt.meta')
with tf.Session() as sess:
#载入模型
saver.restore(sess,'model.ckpt')
#载入 要计算的张量名
input, computed_tensor = tf.get_default_graph().get_tensor_by_name(['input', 'computed_tensor:0'])
# 计算张量
sess.run(computed_tensor, feed_dict = {
input:input_data})
3、gfile.GFile()
from tensorflow.python.platform import gfile
from tensorflow.python.framework import graph_util
#保存模型,及要计算的 tensor
graph_def = tf.get_default_graph().as_graph_def()
output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['input', 'computed_tensor'])
with gfile.GFile('model.pb', "wb") as f:
f.write(output_graph_def.SerializeToString())
#加载模型,以及要计算的 tensor
with tf.Session() as sess:
with gfile.FastGFile('model.pb', "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
input, computed_tensor = tf.import_graph_def(graph_def,return_elements=["input", "computed_tensor:0"])
sess.run(computed_tensor, feed_dict = {
input:input_data})
边栏推荐
猜你喜欢
随机推荐
SaaS启动阶段增长指南(上)
CTF攻防世界
Kubernetes与OpenStack
Analysis of WLAN - Wireless Local Area Network
Kubernetes与OpenStack
DHCP's defense mechanism - DHCP Snooping (DHCP snooping)
Taro小程序跨端开发入门实战
thinkphp5 if else的表达式怎么写?
如何实现call、apply、bind
laravel6框架跨域请求利器之 Laravel CORS 扩展包的安装和使用
Sql注入以及靶场演示
flutter 基本类写法
Upload-labs Pass-02(MIME验证)
Button Wizard for ts API usage
JSDay2- 长度最小的子数组
ArcPy要素批量转dwg
Unity 双生ScrollView滑动冲突问题
Shell脚本学习笔记
雷电模拟器frida脱壳
ArcPy设置全库唯一标识码