tensorflow加载训练好的模型及参数(读取checkpoint)

Ines ·
更新时间:2024-09-21
· 971 次阅读

checkpoint 保存路径

model_path下存有包含多个迭代次数的模型
在这里插入图片描述

1.获取最新保存的模型

即上图中的model-9400

import tensorflow as tf graph=tf.get_default_graph() # 获取当前图 sess=tf.Session() sess.run(tf.global_variables_initializer()) checkpoint_file=tf.train.latest_checkpoint(model_path) saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file)) saver.restore(sess,checkpoint_file) 2.获取某个迭代次数的模型

比如上图中的model-9200

import tensorflow as tf graph=tf.get_default_graph() # 获取当前图 sess=tf.Session() sess.run(tf.global_variables_initializer()) checkpoint_file=os.path.join(model_path,'model-9200') saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file)) saver.restore(sess,checkpoint_file) 获取变量值 ## 得到当前图中所有变量的名称 tensor_name_list=[tensor.name for tensor in graph.as_graph_def().node] # 查看所有变量 print(tensor_name_list) # 获取input_x和input_y的变量值 input_x = graph.get_operation_by_name("input_x").outputs[0] input_y = graph.get_operation_by_name("input_y").outputs[0]
作者:fo系研究僧



checkpoint 训练 参数 模型 tensorflow

需要 登录 后方可回复, 如果你还没有账号请 注册新账号
相关文章