Prima domanda: Sì, la logica è corretta. Il nodo sinistro è Vero e il nodo destro è Falso. Questo è contro-intuitivo; true significa generalmente un valore inferiore.
Seconda domanda: Questo problema è risolto meglio visualizzando l'albero come un grafico con pydotplus. L'attributo 'class_names' di tree.export_graphviz() aggiungerà una dichiarazione di classe alla classe di maggioranza di ciascun nodo. Il codice viene eseguito in iPython.
from sklearn.datasets import load_iris
from sklearn import tree
iris = load_iris()
clf2 = tree.DecisionTreeClassifier()
clf2 = clf2.fit(iris.data, iris.target)
with open("iris.dot", 'w') as f:
f = tree.export_graphviz(clf, out_file=f)
import os
os.unlink('iris.dot')
import pydotplus
dot_data = tree.export_graphviz(clf2, out_file=None)
graph2 = pydotplus.graph_from_dot_data(dot_data)
graph2.write_pdf("iris.pdf")
from IPython.display import Image
dot_data = tree.export_graphviz(clf2, out_file=None,
feature_names=iris.feature_names,
class_names=iris.target_names,
filled=True, rounded=True, # leaves_parallel=True,
special_characters=True)
graph2 = pydotplus.graph_from_dot_data(dot_data)
## Color of nodes
nodes = graph2.get_node_list()
for node in nodes:
if node.get_label():
values = [int(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')];
color = {0: [255,255,224], 1: [255,224,255], 2: [224,255,255],}
values = color[values.index(max(values))]; # print(values)
color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2]); # print(color)
node.set_fillcolor(color)
#
Image(graph2.create_png())
Per quanto riguarda la determinazione della classe la foglia, il tuo esempio non ha foglie con una singola classe, come il set di dati dell'iride fa. Questo è comune e potrebbe richiedere un adattamento eccessivo del modello per ottenere un risultato del genere. Una distribuzione discreta delle classi è il miglior risultato per molti modelli con convalida incrociata.
Godetevi il codice!
Per curiosità, come hai tracciato l'albero decisionale? – Matt
Prima esporta l'albero nel formato JSON (vedi questo [collegamento] (http://www.garysieling.com/blog/rending-scikit-decision-trees-d3-js) e poi traccia l'albero usando d3.js . Oppure puoi usare direttamente la funzione incorporata: 'tree.export_graphviz (clf, out_file = tuo_out_file, feature_names = your_feature_names)' Spero che funzioni, @Matt –