当前位置:网站首页>加载 已训练模型 张量的 几种方法
加载 已训练模型 张量的 几种方法
2022-08-08 22:47:00 【Sarah ฅʕ•̫͡•ʔฅ】
1、saver.restore()
用 saver.restore()加载 模型 之前,首先要 定义 模型 的 计算图,具体操作如下:
#重新定义 计算图
def define_graph(input):
graph_define
return computed_tensor #返回需要计算的张量
with tf.Session() as sess:
#input = tf.placeholder(tf.float,shape,name="input")
computed_tensor = define_graph(input)
#加载模型
saver = tf.train.Saver()
saver.restore(sess, model.ckpt)
#计算 张量
sess.run(computed_tensor, feed_dict={
input:input_data})
2、saver.restore() + tf.train.import_meta_graph()
#利用 tf.train.import_meta_graph()载入计算图
saver = tf.train.import_meta_graph('model.ckpt.meta')
with tf.Session() as sess:
#载入模型
saver.restore(sess,'model.ckpt')
#载入 要计算的张量名
input, computed_tensor = tf.get_default_graph().get_tensor_by_name(['input', 'computed_tensor:0'])
# 计算张量
sess.run(computed_tensor, feed_dict = {
input:input_data})
3、gfile.GFile()
from tensorflow.python.platform import gfile
from tensorflow.python.framework import graph_util
#保存模型,及要计算的 tensor
graph_def = tf.get_default_graph().as_graph_def()
output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['input', 'computed_tensor'])
with gfile.GFile('model.pb', "wb") as f:
f.write(output_graph_def.SerializeToString())
#加载模型,以及要计算的 tensor
with tf.Session() as sess:
with gfile.FastGFile('model.pb', "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
input, computed_tensor = tf.import_graph_def(graph_def,return_elements=["input", "computed_tensor:0"])
sess.run(computed_tensor, feed_dict = {
input:input_data})
边栏推荐
猜你喜欢
随机推荐
The Socket (Socket)
机器学习建模高级用法!构建企业级AI建模流水线
Liquor Daily Question ---- Find the nth Fibonacci number
关于OD的bp send断点 常用断点(OD)
You know you every day in the use of NAT?
MySQL8.0 及 SQL 注入
6.8.3 sigqueue函数
JS中的原型与原型链
JSDay2- 长度最小的子数组
JS中数组扁平化的几种方法
选择排序
JS中的作用域与作用域链
数组去重的几种方法
The concept of GIL and pools
Upload-labs Pass-05
Mysql数据库身份证统计sql数据库加密等操作
iptables防火墙内容全解
浅析WLAN——无线局域网
Kubernetes 资源编排系列之二: Helm 篇
三国战绩 风云再起 网络版 物品序号 和 基址列表