2016-06-27 40 views
8

Ho iniziato di recente a lavorare con tensorflow, quindi sono ancora alle prese con le basi.Come prevedere una sequenza semplice usando seq2seq da tensorflow?

Volevo creare una semplice previsione seq2seq.

  • ingresso è lista di numeri tra 0 e 1.
  • uscita è primo numero elenco e il resto dei numeri moltiplicato per primo.

Sono riuscito a valutare le prestazioni del modello e ottimizzare i pesi. La cosa che ho dovuto affrontare è come fare previsioni con un modello addestrato.

model_outputs, states = seq2seq.basic_rnn_seq2seq(encoder_inputs, 
                decoder_inputs, 
                rnn_cell.BasicLSTMCell(data_point_dim, state_is_tuple=True)) 

Per generare model_outputs mi servono sia valori di uscita di ingresso e per il modello, che è buono per la valutazione, ma in previsione Ho solo valori di ingresso. Immagino di dover fare qualcosa con gli stati, ma non sono sicuro di come trasformarli in una sequenza di galleggianti.

codice completo è disponibile qui https://gist.github.com/anonymous/be405097927758acca158666854600a2

risposta

4

Quando ti alleni, si dà l'ingresso decoder ad ogni passo temporale decoder come l'output desiderato. Durante il test, non si ha l'output desiderato, quindi il meglio che si può fare è campionare un'uscita. Questo sarà l'input per il prossimo timestep.

TLDR; Alimentare l'uscita del decoder ad ogni timestep come input per il timestep successivo.

Edit: codici Alcuni TF

Il basic_rnn_seq2seq funzione ritorno s rnn_decoder (decoder_inputs, enc_states [-1], cellulare)

diamo un'occhiata al rnn_decoder: def rnn_decoder (decoder_inputs, initial_state, cell, loop_function = Nessuno, scope = n): ....

loop_function: se non None, questa funzione sarà applicata a i-esima uscita per generare i + 1-esimo ingresso e decoder_inputs verrà ignorata , eccetto per il primo elemento (simbolo "GO"). Questo può essere usato per la decodifica, ma anche per la formazione per emulare http://arxiv.org/pdf/1506.03099v2.pdf.

durante la decodifica, è necessario impostare questo loop_function = Vero

raccomando guardando il file translate.py nella biblioteca seq2seq tensorflow per vedere come questo viene gestito.

+0

se ho capito correttamente andando da quella soluzione, ignorerei la variabile stati e basta usare l'output in session.run (output, feed_dict = feed) per ottenere risultati? Non dovrebbe esserci un modo per utilizzare gli stati in questo processo? Ho lavorato con scikit-learn e speravo che ci fosse un modo per creare qualcosa come il metodo model.predict –

+0

Aggiunti ulteriori chiarimenti ma sì, ignorerai lo stato nascosto perché gli stati nascosti sono esattamente così: sono nascosti stati che sono già stati utilizzati per creare comunque l'output finale del RNN. Per quanto riguarda il cambio di codice, consiglio di dare un'occhiata all'esempio di Seq2seq wmt fornito da Tensorflow. Credo che tu debba modificare il feed di output. – user4383691

+0

Sto anche cercando di far funzionare qualcosa del genere. Qualcuno potrebbe fornire un semplice esempio di lavoro che fa ciò che è descritto in questa risposta? –

0

La precedente risposta dell'utente4383691 è incompleta. Ho lo stesso problema, e dopo aver scavato nella rnn_decoder, trovato questo: Il modello applica il loop_fn all'uscita esimo, così Vero non ha senso in quanto non è una funzione. Dovresti creare una funzione che possa contenere l'output ith e restituire l'output di i + 1 th. Sono ancora in procinto di fare una tale funzione, e aggiornerò non appena è fatto. sguardo

0

Let al source code:

prev = None for i, inp in enumerate(decoder_inputs): 
    if loop_function is not None and prev is not None: 
     with variable_scope.variable_scope("loop_function", reuse=True): 
     inp = loop_function(prev, i) 
    if i > 0: 
     variable_scope.get_variable_scope().reuse_variables() 
    output, state = cell(inp, state) 
    outputs.append(output) 
    if loop_function is not None: 
     prev = output 

Il ciclo enumera i decoder_inputs se ci si allena con i decoder_inputs forniti o test senza gli ingressi. È perché il decoder_inputs viene sostituito dall'output della funzione loop_ (nella quarta riga dello snippet sopra) durante il test.

In genere è possibile compilare dec_inputs con end_ids come here.

while len(dec_inputs) < self._hps.dec_timesteps: 
    dec_inputs.append(end_id) 
    while len(targets) < self._hps.dec_timesteps: 
    targets.append(end_id)