TL; DR: Se è possibile definire i due calcoli come funzioni Python, è necessario farlo. Se non ci riesci, ci sono funzionalità più avanzate in TensorFlow per serializzare e importare grafici, che ti consente di comporre grafici da diverse fonti.
Un modo per fare questo in tensorflow è quello di costruire i calcoli disgiunti come separato tf.Graph
oggetti, poi convertirli in buffer di protocollo serializzati utilizzando Graph.as_graph_def()
:
with tf.Graph().as_default() as g_1:
input = tf.placeholder(tf.float32, name="input")
y = f(input)
# NOTE: using identity to get a known name for the output tensor.
output = tf.identity(y, name="output")
gdef_1 = g_1.as_graph_def()
with tf.Graph().as_default() as g_2: # NOTE: g_2 not g_1
input = tf.placeholder(tf.float32, name="input")
z = g(input)
output = tf.identity(y, name="output")
gdef_2 = g_2.as_graph_def()
allora si potrebbe comporre gdef_1
e gdef_2
in un terzo grafico , utilizzando tf.import_graph_def()
:
with tf.Graph().as_default() as g_combined:
x = tf.placeholder(tf.float32, name="")
# Import gdef_1, which performs f(x).
# "input:0" and "output:0" are the names of tensors in gdef_1.
y, = tf.import_graph_def(gdef_1, input_map={"input:0": x},
return_elements=["output:0"])
# Import gdef_2, which performs g(y)
z, = tf.import_graph_def(gdef_2, input_map={"input:0": y},
return_elements=["output:0"]
c'è una ragione che non posso allenare con le risultanti di cui sopra, vale a dire qualcosa del tipo 'tf.train.AdamOptimizer(). Minimizzare (tf.nn.l2_loss (z-x))'? Ottengo qualcosa come "Nessuna variabile da ottimizzare" – bge0
Questo è sfortunatamente corretto. La soluzione è di fare 'Vars = op.outputs [0] per op in tf.get_default_graph(). Get_operations() se op.type == "Variabile"]' 'quindi passare var_list = vars' a' ridurre al minimo() ' . – mrry
Grazie per la rapida risposta! Nel tuo esempio per 'y = f (ingresso)' Ho provato utilizzando un semplice 'tf.mul (w, input)' 'dove w ~ N (0, 0,01)' [cioè a tf.Variabile]. Utilizzando la collezione di variabili io vedo 'W' essere raccolte, ma ancora ottengo questo errore:' TypeError: argomento non è un tf.Variable: Tensor ("import/w: 0", DTYPE = float32_ref) 'mi – bge0