2015-11-17 3 views
21

Se si dispone di due grafici disgiunti, e volete collegarli, trasformando questo:Tensorflow: come sostituire un nodo in un grafico di calcolo?

x = tf.placeholder('float') 
y = f(x) 

y = tf.placeholder('float') 
z = f(y) 

in questo:

x = tf.placeholder('float') 
y = f(x) 
z = g(y) 

C'è un modo per farlo? Sembra che potrebbe rendere la costruzione più facile in alcuni casi.

Ad esempio se si dispone di un grafico che ha l'immagine di ingresso come tf.placeholder e si desidera ottimizzare l'immagine di input, stile di sogno profondo, esiste un modo per sostituire semplicemente il segnaposto con un nodo tf.variable? O devi pensarci prima di costruire il grafico?

risposta

18

TL; DR: Se è possibile definire i due calcoli come funzioni Python, è necessario farlo. Se non ci riesci, ci sono funzionalità più avanzate in TensorFlow per serializzare e importare grafici, che ti consente di comporre grafici da diverse fonti.

Un modo per fare questo in tensorflow è quello di costruire i calcoli disgiunti come separato tf.Graph oggetti, poi convertirli in buffer di protocollo serializzati utilizzando Graph.as_graph_def():

with tf.Graph().as_default() as g_1: 
    input = tf.placeholder(tf.float32, name="input") 
    y = f(input) 
    # NOTE: using identity to get a known name for the output tensor. 
    output = tf.identity(y, name="output") 

gdef_1 = g_1.as_graph_def() 

with tf.Graph().as_default() as g_2: # NOTE: g_2 not g_1  
    input = tf.placeholder(tf.float32, name="input") 
    z = g(input) 
    output = tf.identity(y, name="output") 

gdef_2 = g_2.as_graph_def() 

allora si potrebbe comporre gdef_1 e gdef_2 in un terzo grafico , utilizzando tf.import_graph_def():

with tf.Graph().as_default() as g_combined: 
    x = tf.placeholder(tf.float32, name="") 

    # Import gdef_1, which performs f(x). 
    # "input:0" and "output:0" are the names of tensors in gdef_1. 
    y, = tf.import_graph_def(gdef_1, input_map={"input:0": x}, 
          return_elements=["output:0"]) 

    # Import gdef_2, which performs g(y) 
    z, = tf.import_graph_def(gdef_2, input_map={"input:0": y}, 
          return_elements=["output:0"] 
+0

c'è una ragione che non posso allenare con le risultanti di cui sopra, vale a dire qualcosa del tipo 'tf.train.AdamOptimizer(). Minimizzare (tf.nn.l2_loss (z-x))'? Ottengo qualcosa come "Nessuna variabile da ottimizzare" – bge0

+3

Questo è sfortunatamente corretto. La soluzione è di fare 'Vars = op.outputs [0] per op in tf.get_default_graph(). Get_operations() se op.type == "Variabile"]' 'quindi passare var_list = vars' a' ridurre al minimo() ' . – mrry

+0

Grazie per la rapida risposta! Nel tuo esempio per 'y = f (ingresso)' Ho provato utilizzando un semplice 'tf.mul (w, input)' 'dove w ~ N (0, 0,01)' [cioè a tf.Variabile]. Utilizzando la collezione di variabili io vedo 'W' essere raccolte, ma ancora ottengo questo errore:' TypeError: argomento non è un tf.Variable: Tensor ("import/w: 0", DTYPE = float32_ref) 'mi – bge0

3

Se si desidera combinare modelli addestrati (per esempio di riutilizzare parti di un modello preaddestrato in un nuovo modello), y Puoi usare un Saver per salvare un checkpoint del primo modello, quindi ripristinare quel modello (interamente o parzialmente) in un altro modello.

Per esempio, diciamo che si desidera riutilizzare il modello 1 di pesi w nel modello 2, e anche convertire x da un segnaposto a una variabile:

with tf.Graph().as_default() as g1: 
    x = tf.placeholder('float') 
    w = tf.Variable(1., name="w") 
    y = x * w 
    saver = tf.train.Saver() 

with tf.Session(graph=g1) as sess: 
    w.initializer.run() 
    # train... 
    saver.save(sess, "my_model1.ckpt") 

with tf.Graph().as_default() as g2: 
    x = tf.Variable(2., name="v") 
    w = tf.Variable(0., name="w") 
    z = x + w 
    restorer = tf.train.Saver([w]) # only restore w 

with tf.Session(graph=g2) as sess: 
    x.initializer.run() # x now needs to be initialized 
    restorer.restore(sess, "my_model1.ckpt") # restores w=1 
    print(z.eval()) # prints 3. 
+0

Correggimi se ho torto ma questo approccio non salva la struttura del tuo grafico, quindi devi ridefinire ogni volta che vuoi usare le variabili. – Pietrko

3

Si scopre che tf.train.import_meta_graph passa tutti gli argomenti aggiuntivi per il sottostante import_scoped_meta_graph che ha l'argomento input_map e lo utilizza quando arriva alla propria chiamata (interna) di import_graph_def.

Non è documentato e mi ci è voluto troppo tempo per trovarlo, ma funziona!

+0

La ringrazio molto, ho cercare una tale risposta per molto tempo. – Tom