2015-12-28 17 views
8

Sto cercando di implementare un suggerimento di risposte: Tensorflow: how to save/restore a model?tensorflow: salvataggio e il ripristino della sessione

Ho un oggetto che avvolge un modello tensorflow in stile sklearn.

import tensorflow as tf 
class tflasso(): 
    saver = tf.train.Saver() 
    def __init__(self, 
       learning_rate = 2e-2, 
       training_epochs = 5000, 
        display_step = 50, 
        BATCH_SIZE = 100, 
        ALPHA = 1e-5, 
        checkpoint_dir = "./", 
      ): 
     ... 

    def _create_network(self): 
     ... 


    def _load_(self, sess, checkpoint_dir = None): 
     if checkpoint_dir: 
      self.checkpoint_dir = checkpoint_dir 

     print("loading a session") 
     ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir) 
     if ckpt and ckpt.model_checkpoint_path: 
      self.saver.restore(sess, ckpt.model_checkpoint_path) 
     else: 
      raise Exception("no checkpoint found") 
     return 

    def fit(self, train_X, train_Y , load = True): 
     self.X = train_X 
     self.xlen = train_X.shape[1] 
     # n_samples = y.shape[0] 

     self._create_network() 
     tot_loss = self._create_loss() 
     optimizer = tf.train.AdagradOptimizer(self.learning_rate).minimize(tot_loss) 

     # Initializing the variables 
     init = tf.initialize_all_variables() 
     " training per se" 
     getb = batchgen(self.BATCH_SIZE) 

     yvar = train_Y.var() 
     print(yvar) 
     # Launch the graph 
     NUM_CORES = 3 # Choose how many cores to use. 
     sess_config = tf.ConfigProto(inter_op_parallelism_threads=NUM_CORES, 
                  intra_op_parallelism_threads=NUM_CORES) 
     with tf.Session(config= sess_config) as sess: 
      sess.run(init) 
      if load: 
       self._load_(sess) 
      # Fit all training data 
      for epoch in range(self.training_epochs): 
       for (_x_, _y_) in getb(train_X, train_Y): 
        _y_ = np.reshape(_y_, [-1, 1]) 
        sess.run(optimizer, feed_dict={ self.vars.xx: _x_, self.vars.yy: _y_}) 
       # Display logs per epoch step 
       if (1+epoch) % self.display_step == 0: 
        cost = sess.run(tot_loss, 
          feed_dict={ self.vars.xx: train_X, 
            self.vars.yy: np.reshape(train_Y, [-1, 1])}) 
        rsq = 1 - cost/yvar 
        logstr = "Epoch: {:4d}\tcost = {:.4f}\tR^2 = {:.4f}".format((epoch+1), cost, rsq) 
        print(logstr) 
        self.saver.save(sess, self.checkpoint_dir + 'model.ckpt', 
         global_step= 1+ epoch) 

      print("Optimization Finished!") 
     return self 

quando ho eseguito:

tfl = tflasso() 
tfl.fit(train_X, train_Y , load = False) 

ottengo uscita:

Epoch: 50 cost = 38.4705 R^2 = -1.2036 
    b1: 0.118122 
Epoch: 100 cost = 26.4506 R^2 = -0.5151 
    b1: 0.133597 
Epoch: 150 cost = 22.4330 R^2 = -0.2850 
    b1: 0.142261 
Epoch: 200 cost = 20.0361 R^2 = -0.1477 
    b1: 0.147998 

Tuttavia, quando si tenta di recuperare i parametri (anche senza uccidere l'oggetto): tfl.fit(train_X, train_Y , load = True)

Ottengo risultati strani. Innanzitutto, il valore caricato non corrisponde a quello salvato.

loading a session 
loaded b1: 0.1   <------- Loaded another value than saved 
Epoch: 50 cost = 30.8483 R^2 = -0.7670 
    b1: 0.137484 

Qual è il modo corretto per caricare e probabilmente prima ispezionare le variabili salvate?

+0

documentazione tensorflow è privo di esempi piuttosto semplice, è necessario scavare nelle cartelle esempi e dare un senso di esso per lo più da soli – diffeomorphism

risposta

10

TL; DR: Si dovrebbe cercare di rielaborare questa classe in modo che self.create_network() è chiamato (i) una sola volta, e (ii) prima che il tf.train.Saver() è costruito.

Qui ci sono due problemi, dovuti alla struttura del codice e al comportamento predefinito dello tf.train.Saver constructor. Quando costruisci un risparmiatore senza argomenti (come nel tuo codice), raccoglie il set corrente di variabili nel tuo programma e aggiunge operazioni al grafico per salvarle e ripristinarle. Nel tuo codice, quando chiami tflasso(), costruirà un risparmiatore e non ci saranno variabili (perché create_network() non è stato ancora chiamato). Di conseguenza, il checkpoint dovrebbe essere vuoto.

Il secondo problema è che — per impostazione predefinita — il formato di un checkpoint salvato è una mappa da name property of a variable al valore corrente. Se si creano due variabili con lo stesso nome, saranno automaticamente "uniquified" di tensorflow:

v = tf.Variable(..., name="weights") 
assert v.name == "weights" 
w = tf.Variable(..., name="weights") 
assert v.name == "weights_1" # The "_1" is added by TensorFlow. 

La conseguenza di ciò è che, quando si chiama self.create_network() nella seconda chiamata a tfl.fit(), le variabili avranno tutti nomi diversi dai nomi che sono memorizzati nel punto di controllo — o sarebbero stati se il risparmiatore fosse stato costruito dopo la rete. (È possibile evitare questo comportamento passando un dizionario nome- Variable al costruttore saver, ma questo è di solito abbastanza imbarazzante.)

Esistono due soluzioni principali:

  1. In ogni chiamata a tflasso.fit(), creare l'intero modello di nuovo, definendo un nuovo tf.Graph, quindi in quel grafico la creazione della rete e la creazione di un tf.train.Saver.

  2. RACCOMANDATO Creare la rete, poi l'tf.train.Saver nel costruttore tflasso, e riutilizzare questo grafico su ogni chiamata a tflasso.fit().Si noti che potrebbe essere necessario fare un po 'più di lavoro per riorganizzare le cose (in particolare, non sono sicuro di ciò che si fa con self.X e self.xlen) ma dovrebbe essere possibile ottenere ciò con placeholders e l'alimentazione.

+0

grazie! Il 'xlen' è usato all'interno di' self._create_network() 'per impostare la dimensione di input di' X' (placeholder init: 'self.vars.xx = tf.placeholder (" float ", shape = [None, self.xlen ]) '). Da quello che dici, il modo preferito è passare 'xlen' all'inizializzatore. –

+0

C'è un modo per reimpostare univocatore/cancellare vecchie variabili tf dopo la reinizializzazione dell'oggetto? –

+1

Per fare ciò è necessario creare un nuovo 'tf.Graph' e impostarlo come predefinito prima di (i) creare la rete e (ii) creare un' Saver'. Se si avvolge il corpo di 'tflasso.fit()' in a 'con tf.Graph(). As_default():' block, e si sposta la costruzione 'Saver' all'interno di quel blocco, i nomi dovrebbero essere gli stessi ogni volta che si chiama 'fit()'. – mrry