2014-05-09 7 views
14

Ho due problemi con la comprensione del risultato dell'albero decisionale da parte di scikit-learn. Ad esempio, questo è uno dei miei alberi di decisione:come spiegare l'albero delle decisioni da scikit-learn

enter image description here La mia domanda è che come posso usare l'albero?

La prima domanda è che: se un campione soddisfatta la condizione, quindi va al ramo LEFT (se esiste), altrimenti viene RIGHT. Nel mio caso, se un campione con X [7]> 63521.3984. Quindi il campione andrà alla casella verde. Corretta?

La seconda domanda è che: quando un campione raggiunge il nodo foglia, come posso sapere a quale categoria appartiene? In questo esempio, ho tre categorie da classificare. Nella casella rossa, ci sono 91, 212 e 113 campioni sono soddisfatti della condizione, rispettivamente. Ma come posso decidere la categoria? So che esiste una funzione clf.predict (campione) per indicare la categoria. Posso farlo dal grafico ??? Mille grazie.

+1

Per curiosità, come hai tracciato l'albero decisionale? – Matt

+4

Prima esporta l'albero nel formato JSON (vedi questo [collegamento] (http://www.garysieling.com/blog/rending-scikit-decision-trees-d3-js) e poi traccia l'albero usando d3.js . Oppure puoi usare direttamente la funzione incorporata: 'tree.export_graphviz (clf, out_file = tuo_out_file, feature_names = your_feature_names)' Spero che funzioni, @Matt –

risposta

21

La riga value in ciascuna casella indica quanti campioni in quel nodo rientrano in ciascuna categoria, nell'ordine. Ecco perché, in ogni casella, i numeri in value sommano il numero mostrato in sample. Ad esempio, nella casella rossa, 91 + 212 + 113 = 416. Questo significa che se si raggiunge questo nodo, ci sono 91 punti nella categoria 1, 212 nella categoria 2 e 113 nella categoria 3.

Se si prevede di prevedere l'esito per un nuovo punto dati che ha raggiunto tale foglia nell'albero delle decisioni si pronostica la categoria 2, poiché quella è la categoria più comune per i campioni su quel nodo.

+0

Mi interessava sapere quale valore appartiene a quale classe. 'DecisiontreeClassifier.classes' contiene queste informazioni. – ezdazuzena

+0

(Risposta utile: per chiarire usando l'indicizzazione python però: un atterraggio campione nella casella rossa sarebbe previsto (conteggio 212) come categoria 1, piuttosto che categoria 0 (91) o categoria 2 (113) :-)) –

0

In base al libro "Imparare scikit-learn: Machine Learning in Python", l'albero decisionale rappresenta una serie di decisioni basate sui dati di addestramento.

! (http://i.imgur.com/vM9fJLy.png)

classificare un esempio, dobbiamo rispondere alla domanda ad ogni nodo. Ad esempio, il sesso è < = 0,5? (stiamo parlando di una donna?). Se la risposta è sì, si passa al nodo figlio sinistro nell'albero; altrimenti si passa al nodo figlio destro. Continui a rispondere alle domande (era lei in terza classe ?, era lei in prima classe ?, e aveva meno di 13 anni?), Finché non raggiungi una foglia. Quando ci sei, la previsione corrisponde alla classe di destinazione che ha più istanze.

2

Prima domanda: Sì, la logica è corretta. Il nodo sinistro è Vero e il nodo destro è Falso. Questo è contro-intuitivo; true significa generalmente un valore inferiore.

Seconda domanda: Questo problema è risolto meglio visualizzando l'albero come un grafico con pydotplus. L'attributo 'class_names' di tree.export_graphviz() aggiungerà una dichiarazione di classe alla classe di maggioranza di ciascun nodo. Il codice viene eseguito in iPython.

from sklearn.datasets import load_iris 
from sklearn import tree 
iris = load_iris() 
clf2 = tree.DecisionTreeClassifier() 
clf2 = clf2.fit(iris.data, iris.target) 

with open("iris.dot", 'w') as f: 
    f = tree.export_graphviz(clf, out_file=f) 

import os 
os.unlink('iris.dot') 

import pydotplus 
dot_data = tree.export_graphviz(clf2, out_file=None) 
graph2 = pydotplus.graph_from_dot_data(dot_data) 
graph2.write_pdf("iris.pdf") 

from IPython.display import Image 
dot_data = tree.export_graphviz(clf2, out_file=None, 
        feature_names=iris.feature_names, 
        class_names=iris.target_names, 
        filled=True, rounded=True, # leaves_parallel=True, 
        special_characters=True) 
graph2 = pydotplus.graph_from_dot_data(dot_data) 

## Color of nodes 
nodes = graph2.get_node_list() 

for node in nodes: 
    if node.get_label(): 
     values = [int(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')]; 
     color = {0: [255,255,224], 1: [255,224,255], 2: [224,255,255],} 
     values = color[values.index(max(values))]; # print(values) 
     color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2]); # print(color) 
     node.set_fillcolor(color) 
# 

Image(graph2.create_png()) 

enter image description here

Per quanto riguarda la determinazione della classe la foglia, il tuo esempio non ha foglie con una singola classe, come il set di dati dell'iride fa. Questo è comune e potrebbe richiedere un adattamento eccessivo del modello per ottenere un risultato del genere. Una distribuzione discreta delle classi è il miglior risultato per molti modelli con convalida incrociata.

Godetevi il codice!