首先定义一个tf.train.Saver类:

saver = tf.train.Saver(max_to_keep=1)

其中,max_to_keep参数设定只保存最后一个参数,默认值是5,即保存最后5个模型,如果设置成0,训练过程中的所有模型都会被保存。

SRE实战 互联网时代守护先锋,助力企业售后服务体系运筹帷幄!一键直达领取阿里云限量特价优惠。

模型训练好以后,保存模型:

saver.save(sess, ckpt_dir + "/nn_model.ckpt", global_step=1)

其中,sess是Session,ckpt_dir + "/nn_model.ckpt"是保存的路径和名称,global_step是模型名称的后缀名,由于我们只保存最后一个模型,所以可以设置为1,如果每一个模型都想保存,可以设置成训练的epoch。

载入模型比较简单:

saver.restore(sess, model_file)

其中,sess是Session,model_file是模型的路径和名称。

扫码关注我们
微信号:SRE实战
拒绝背锅 运筹帷幄