一、TensorFlow常規模型加載方法
保存模型
tf.train.Saver()類,.save(sess, ckpt文件目錄)方法
參數名稱 | 功能說明 | 默認值 |
var_list | Saver中存儲變量集合 | 全局變量集合 |
reshape | 加載時是否恢復變量形狀 | True |
sharded | 是否將變量輪循放在所有設備上 | True |
max_to_keep | 保留最近檢查點個數 | 5 |
restore_sequentially | 是否按順序恢復變量,模型較大時順序恢復內存消耗小 | True |
var_list是字典形式{變量名字符串: 變量符號},相對應的restore也根據同樣形式的字典將ckpt中的字符串對應的變量加載給程序中的符號。
如果Saver給定了字典作為加載方式,則按照字典來,如:saver = tf.train.Saver({"v/ExponentialMovingAverage":v}),否則每個變量尋找自己的name屬性在ckpt中的對應值進行加載。
加載模型
當我們基于checkpoint文件(ckpt)加載參數時,實際上我們使用Saver.restore取代了initializer的初始化
checkpoint文件會記錄保存信息,通過它可以定位最新保存的模型:
ckpt = tf.train.get_checkpoint_state('./model/')print(ckpt.model_checkpoint_path)
.meta文件保存了當前圖結構
.index文件保存了當前參數名
.data文件保存了當前參數值
tf.train.import_meta_graph函數給出model.ckpt-n.meta的路徑后會加載圖結構,并返回saver對象
ckpt = tf.train.get_checkpoint_state('./model/')
tf.train.Saver函數會返回加載默認圖的saver對象,saver對象初始化時可以指定變量映射方式,根據名字映射變量(『TensorFlow』滑動平均)
saver = tf.train.Saver({"v/ExponentialMovingAverage":v})
saver.restore函數給出model.ckpt-n的路徑后會自動尋找參數名-值文件進行加載
saver.restore(sess,'./model/model.ckpt-0')saver.restore(sess,ckpt.model_checkpoint_path)
1.不加載圖結構,只加載參數
由于實際上我們參數保存的都是Variable變量的值,所以其他的參數值(例如batch_size)等,我們在restore時可能希望修改,但是圖結構在train時一般就已經確定了,所以我們可以使用tf.Graph().as_default()新建一個默認圖(建議使用上下文環境),利用這個新圖修改和變量無關的參值大小,從而達到目的。
新聞熱點
疑難解答