2015-11-19 8 views
8

Ecco il codice che sto cercando di correre-Come cambio il dtype in TensorFlow per un file csv?

import tensorflow as tf 
import numpy as np 
import input_data 

filename_queue = tf.train.string_input_producer(["cs-training.csv"]) 

reader = tf.TextLineReader() 
key, value = reader.read(filename_queue) 

record_defaults = [[1], [1], [1], [1], [1], [1], [1], [1], [1], [1], [1]] 
col1, col2, col3, col4, col5, col6, col7, col8, col9, col10, col11 = tf.decode_csv(
    value, record_defaults=record_defaults) 
features = tf.concat(0, [col2, col3, col4, col5, col6, col7, col8, col9, col10, col11]) 

with tf.Session() as sess: 
    # Start populating the filename queue. 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 

    for i in range(1200): 
    # Retrieve a single instance: 
    print i 
    example, label = sess.run([features, col1]) 
    try: 
     print example, label 
    except: 
     pass 

    coord.request_stop() 
    coord.join(threads) 

questo codice di ritorno l'errore sotto.

--------------------------------------------------------------------------- 
InvalidArgumentError      Traceback (most recent call last) 
<ipython-input-23-e42fe2609a15> in <module>() 
     7  # Retrieve a single instance: 
     8  print i 
----> 9  example, label = sess.run([features, col1]) 
    10  try: 
    11   print example, label 

/root/anaconda/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in run(self, fetches, feed_dict) 
    343 
    344  # Run request and get response. 
--> 345  results = self._do_run(target_list, unique_fetch_targets, feed_dict_string) 
    346 
    347  # User may have fetched the same tensor multiple times, but we 

/root/anaconda/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_run(self, target_list, fetch_list, feed_dict) 
    417   # pylint: disable=protected-access 
    418   raise errors._make_specific_exception(node_def, op, e.error_message, 
--> 419            e.code) 
    420   # pylint: enable=protected-access 
    421  raise e_type, e_value, e_traceback 

InvalidArgumentError: Field 1 in record 0 is not a valid int32: 0.766126609 

Ha un sacco di informazioni che lo seguono e che penso sia irrilevante per il problema. Ovviamente il problema è che molti dei dati che sto alimentando al programma non sono del dtype int32. Sono per lo più numeri fluttuanti. Ho provato alcune cose per cambiare il dtype come impostare esplicitamente l'argomento dtype=float in tf.decode_csv e tf.concat. Nessuno dei due ha funzionato. È un argomento non valido. Per finire, non so se questo codice farà effettivamente una previsione sui dati. Voglio pronosticare se col1 sarà un 1 o uno 0 e non vedo nulla nel codice che possa suggerire che realizzerà quella previsione. Forse salverò la domanda per un thread diverso. Qualsiasi aiuto è molto apprezzato!

risposta

1

La risposta a cambiare il DTYPE è quello di cambiare solo le impostazioni predefinite come cd

record_defaults = [[1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.]] 

Dopo averlo fatto, se si stampa fuori col1, riceverete questo messaggio.

Tensor("DecodeCSV_43:0", shape=TensorShape([]), dtype=float32) 

Ma c'è un altro errore che si andrà in contro, which has been answered here. Per ricapitolare la risposta, la soluzione è quello di cambiare tf.concat-tf.pack in questo modo.

features = tf.pack([col2, col3, col4, col5, col6, col7, col8, col9, col10, col11]) 
13

L'interfaccia a tf.decode_csv() è un po 'complicata. Il dtype di ogni colonna è determinato dall'elemento corrispondente dell'argomento record_defaults. Il valore per record_defaults nel codice viene interpretato come ogni colonna con tf.int32 come tipo, che genera un errore quando rileva dati in virgola mobile.

Diciamo che avere i seguenti dati CSV, che contiene tre colonne interi, seguite da una colonna in virgola mobile:

4, 8, 9, 4.5 
2, 5, 1, 3.7 
2, 2, 2, 0.1 

Supponendo che tutte le colonne sono necessario, si potrebbe costruire record_defaults come segue:

value = ... 

record_defaults = [tf.constant([], dtype=tf.int32), # Column 0 
        tf.constant([], dtype=tf.int32), # Column 1 
        tf.constant([], dtype=tf.int32), # Column 2 
        tf.constant([], dtype=tf.float32)] # Column 3 

col0, col1, col2, col3 = tf.decode_csv(value, record_defaults=record_defauts) 

assert col0.dtype == tf.int32 
assert col1.dtype == tf.int32 
assert col2.dtype == tf.int32 
assert col3.dtype == tf.float32 

Un valore vuoto in record_defaults indica che il valore è obbligatorio. In alternativa, se (ad esempio) colonna 2 è consentito di avere valori mancanti, definirebbe record_defaults come segue:

record_defaults = [tf.constant([], dtype=tf.int32),  # Column 0 
        tf.constant([], dtype=tf.int32),  # Column 1 
        tf.constant([0], dtype=tf.int32), # Column 2 
        tf.constant([], dtype=tf.float32)] # Column 3 

La seconda parte delle vostre preoccupazioni domanda come costruire e formare un modello che predice il valore di uno dei le colonne dai dati di input. Attualmente, il programma non lo fa: semplicemente concatena le colonne in un singolo tensore, chiamato features. Dovrai definire e addestrare un modello, che interpreti quei dati. Uno dei più semplici di questi approcci è la regressione lineare, e potresti trovare questo tutorial su linear regression in TensorFlow adattabile al tuo problema.