Ecco il codice che sto cercando di correre-Come cambio il dtype in TensorFlow per un file csv?
import tensorflow as tf
import numpy as np
import input_data
filename_queue = tf.train.string_input_producer(["cs-training.csv"])
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
record_defaults = [[1], [1], [1], [1], [1], [1], [1], [1], [1], [1], [1]]
col1, col2, col3, col4, col5, col6, col7, col8, col9, col10, col11 = tf.decode_csv(
value, record_defaults=record_defaults)
features = tf.concat(0, [col2, col3, col4, col5, col6, col7, col8, col9, col10, col11])
with tf.Session() as sess:
# Start populating the filename queue.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(1200):
# Retrieve a single instance:
print i
example, label = sess.run([features, col1])
try:
print example, label
except:
pass
coord.request_stop()
coord.join(threads)
questo codice di ritorno l'errore sotto.
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-23-e42fe2609a15> in <module>()
7 # Retrieve a single instance:
8 print i
----> 9 example, label = sess.run([features, col1])
10 try:
11 print example, label
/root/anaconda/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in run(self, fetches, feed_dict)
343
344 # Run request and get response.
--> 345 results = self._do_run(target_list, unique_fetch_targets, feed_dict_string)
346
347 # User may have fetched the same tensor multiple times, but we
/root/anaconda/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_run(self, target_list, fetch_list, feed_dict)
417 # pylint: disable=protected-access
418 raise errors._make_specific_exception(node_def, op, e.error_message,
--> 419 e.code)
420 # pylint: enable=protected-access
421 raise e_type, e_value, e_traceback
InvalidArgumentError: Field 1 in record 0 is not a valid int32: 0.766126609
Ha un sacco di informazioni che lo seguono e che penso sia irrilevante per il problema. Ovviamente il problema è che molti dei dati che sto alimentando al programma non sono del dtype int32. Sono per lo più numeri fluttuanti. Ho provato alcune cose per cambiare il dtype come impostare esplicitamente l'argomento dtype=float
in tf.decode_csv
e tf.concat
. Nessuno dei due ha funzionato. È un argomento non valido. Per finire, non so se questo codice farà effettivamente una previsione sui dati. Voglio pronosticare se col1 sarà un 1 o uno 0 e non vedo nulla nel codice che possa suggerire che realizzerà quella previsione. Forse salverò la domanda per un thread diverso. Qualsiasi aiuto è molto apprezzato!