2016-05-10 88 views
5

Se tento di importare una definizione TensorFlow grafico salvato conCome posso ottenere di tensorflow 'import_graph_def' per tornare tensori

import tensorflow as tf 
from tensorflow.python.platform import gfile 

with gfile.FastGFile(FLAGS.model_save_dir.format(log_id) + '/graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef() 
    graph_def.ParseFromString(f.read()) 
x, y, y_ = tf.import_graph_def(graph_def, 
           return_elements=['data/inputs', 
               'output/network_activation', 
               'data/correct_outputs'], 
           name='') 

i valori restituiti non sono Tensor s come previsto, ma qualcosa di diverso: invece, ad esempio, , di ottenere x come

Tensor("data/inputs:0", shape=(?, 784), dtype=float32) 

ottengo

name: "data/inputs_1" 
op: "Placeholder" 
attr { 
    key: "dtype" 
    value { 
    type: DT_FLOAT 
    } 
} 
attr { 
    key: "shape" 
    value { 
    shape { 
    } 
    } 
} 

Cioè, invece di ottenere il tensore atteso x ottengo, x.op. Questo mi confonde perché il documentation sembra dire che dovrei ottenere uno Tensor (anche se ci sono un po 'di o s che rendono difficile capire).

Come posso ottenere tf.import_graph_def per restituire specifici Tensor s che posso quindi utilizzare (ad esempio in alimentazione del modello caricato o analisi in esecuzione)?

+0

La seconda riga di codice dovrebbe essere 'da tensorflow.python.platform import gfile'. – tobe

risposta

3

I nomi 'data/inputs', 'output/network_activation' e 'data/correct_outputs' sono in realtà nomi di operazioni. Per arrivare tf.import_graph_def() per tornare tf.Tensor oggetti, si dovrebbe aggiungere l'indice di un'uscita per il nome dell'operazione, che in genere è ':0' per le operazioni single-uscita:

x, y, y_ = tf.import_graph_def(graph_def, 
           return_elements=['data/inputs:0', 
               'output/network_activation:0', 
               'data/correct_outputs:0'], 
           name='')