Qualcuno potrebbe chiarire se lo stato iniziale dell'RNN in TF viene ripristinato per i mini-lotti successivi, oppure viene utilizzato l'ultimo stato del mini-lotto precedente come indicato in Ilya Sutskever et al., ICLR 2015?È stato ripristinato lo stato iniziale RNN per i mini-lotti successivi?
risposta
Le operazioni tf.nn.dynamic_rnn()
o tf.nn.rnn()
consentono di specificare lo stato iniziale dell'RNN utilizzando il parametro initial_state
. Se non si specifica questo parametro, gli stati nascosti verranno inizializzati su zero vettori all'inizio di ciascun batch di addestramento.
In TensorFlow, è possibile avvolgere i tensori in tf.Variable()
per mantenere i loro valori nel grafico tra più sessioni di sessione. Assicurati di contrassegnarli come non addestrabili perché gli ottimizzatori regolano tutte le variabili trainabili per impostazione predefinita.
data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))
cell = tf.nn.rnn_cell.GRUCell(256)
state = tf.Variable(cell.zero_states(batch_size, tf.float32), trainable=False)
output, new_state = tf.nn.dynamic_rnn(cell, data, initial_state=state)
with tf.control_dependencies([state.assign(new_state)]):
output = tf.identity(output)
sess = tf.Session()
sess.run(tf.initialize_all_variables())
sess.run(output, {data: ...})
Non ho testato questo codice ma dovrebbe darvi un suggerimento nella giusta direzione. C'è anche un tf.nn.state_saving_rnn()
a cui puoi fornire un oggetto salvaschermo, ma non l'ho ancora usato.
Oltre alla risposta di danijar, ecco il codice per un LSTM, il cui stato è una tupla (state_is_tuple=True
). Supporta anche più livelli.
Definiamo due funzioni: una per ottenere le variabili di stato con uno stato iniziale zero e una funzione per restituire un'operazione, che possiamo passare a session.run
per aggiornare le variabili di stato con l'ultimo stato nascosto dell'LSTM. risposta
def get_state_variables(batch_size, cell):
# For each layer, get the initial state and make a variable out of it
# to enable updating its value.
state_variables = []
for state_c, state_h in cell.zero_state(batch_size, tf.float32):
state_variables.append(tf.contrib.rnn.LSTMStateTuple(
tf.Variable(state_c, trainable=False),
tf.Variable(state_h, trainable=False)))
# Return as a tuple, so that it can be fed to dynamic_rnn as an initial state
return tuple(state_variables)
def get_state_update_op(state_variables, new_states):
# Add an operation to update the train states with the last state tensors
update_ops = []
for state_variable, new_state in zip(state_variables, new_states):
# Assign the new state to the state variables on this layer
update_ops.extend([state_variable[0].assign(new_state[0]),
state_variable[1].assign(new_state[1])])
# Return a tuple in order to combine all update_ops into a single operation.
# The tuple's actual value should not be used.
return tf.tuple(update_ops)
simili di danijar, possiamo usare che per aggiornare lo stato del LSTM dopo ogni batch:
data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))
cell_layer = tf.contrib.rnn.GRUCell(256)
cell = tf.contrib.rnn.MultiRNNCell([cell_layer] * num_layers)
# For each layer, get the initial state. states will be a tuple of LSTMStateTuples.
states = get_state_variables(batch_size, cell)
# Unroll the LSTM
outputs, new_states = tf.nn.dynamic_rnn(cell, data, initial_state=states)
# Add an operation to update the train states with the last state tensors.
update_op = get_state_update_op(states, new_states)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run([outputs, update_op], {data: ...})
La differenza principale è che state_is_tuple=True
rende lo stato del LSTM un LSTMStateTuple contenente due variabili (stato cellule e stato nascosto) invece di una singola variabile. L'utilizzo di più livelli rende quindi lo stato di LSTM una tupla di LSTMStateTuples, uno per livello.
Nota come lo fai tu crei num_layers _identical_ cells che non è quello che vuoi fare probabilmente –