42

Sto provando ad assegnare un nuovo valore a una variabile tensorflow in python.Come assegnare un valore a una variabile TensorFlow?

import tensorflow as tf 
import numpy as np 

x = tf.Variable(0) 
init = tf.initialize_all_variables() 
sess = tf.InteractiveSession() 
sess.run(init) 

print(x.eval()) 

x.assign(1) 
print(x.eval()) 

Ma l'uscita ottengo è

0 
0 

Quindi il valore non è cambiato. Cosa mi manca?

risposta

72

La dichiarazione x.assign(1) in realtà non assegnare il valore 1-x, ma crea un tf.Operation che si deve esplicitamente corsa per aggiornare la variabile * Una chiamata a Operation.run() o Session.run() può essere utilizzato per eseguire l'operazione.:

assign_op = x.assign(1) 
sess.run(assign_op) # or `assign_op.op.run()` 
print(x.eval()) 
# ==> 1 

(* infatti, esso restituisce un tf.Tensor, corrispondente al valore aggiornato della variabile, per rendere più facile assegnazioni catena.)

+0

Grazie! assign_op.run() restituisce un errore: AttributeError: l'oggetto 'Tensor' non ha attributo 'run'. Ma sess.run (assign_op) funziona perfettamente bene. – abora

+0

In questo esempio, i dati che la 'Variabile'' x' memorizzata in memoria prima dell'operazione 'assign'/tensore mutabile è stata eseguita sovrascritta o è un nuovo tensore creato che memorizza il valore aggiornato? – dannygoldstein

+3

L'attuale implementazione di 'assign()' sovrascrive il valore esistente. – mrry

-4

C'è un approccio più facile:

x = tf.Variable(0) 
x = x + 1 
print x.eval() 
+2

l'o.p. stava esaminando l'uso di 'tf.assign', non di una aggiunta. – vega

6

Prima di tutto è possibile assegnare valori alle variabili/costanti solo alimentando i valori in loro allo stesso modo lo si fa con i segnaposto. Quindi questo è perfettamente legale per fare:

import tensorflow as tf 
x = tf.Variable(0) 
with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    print sess.run(x, feed_dict={x: 3}) 

Per quanto riguarda la vostra confusione con l'operatore tf.assign(). In TF non viene eseguito nulla prima di eseguirlo all'interno della sessione. Quindi devi sempre fare qualcosa di simile a questo: op_name = tf.some_function_that_create_op(params) e quindi all'interno della sessione esegui sess.run(op_name). Usando assegnare a titolo di esempio si farà qualcosa di simile:

import tensorflow as tf 
x = tf.Variable(0) 
y = tf.assign(x, 1) 
with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    print sess.run(x) 
    print sess.run(y) 
    print sess.run(x) 
+1

Nota che l'inserimento del valore tramite 'feed_dict' non assegna quel valore alla variabile. –

2

Inoltre, si deve notare che se si sta utilizzando your_tensor.assign(), poi il tf.global_variables_initializer necessità di non essere chiamato in modo esplicito in quanto l'operazione di assegnazione fa per voi sullo sfondo.

Esempio:

In [212]: w = tf.Variable(12) 
In [213]: w_new = w.assign(34) 

In [214]: with tf.Session() as sess: 
    ...:  sess.run(w_new) 
    ...:  print(w_new.eval()) 

# output 
34 

Tuttavia, questo non inizializzare tutte le variabili, ma sarà inizializzare solo la variabile su cui è stato eseguito su assign.

2

È anche possibile assegnare un nuovo valore a tf.Variable senza aggiungere un'operazione al grafico: tf.Variable.load(value, session). Questa funzione può anche risparmiare l'aggiunta di segnaposto quando si assegna un valore dall'esterno del grafico ed è utile nel caso in cui il grafico sia finalizzato.

import tensorflow as tf 
x = tf.Variable(0) 
sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 
print(sess.run(x)) # Prints 0. 
x.load(1, sess) 
print(sess.run(x)) # Prints 1.