scikit-learnで機械学習:決定木(2)

前回の続きで、Scikit-learnの決定木を使ってみます。

ここから先は難しいことは何もなくて、本家に従いながら走らせてみる。

もろもろimport。

import numpy as np
import pandas as pd
from sklearn import datasets
from sklearn import tree
from sklearn.grid_search import GridSearchCV
from sklearn import cross_validation
import pydot

データの学習

データはフィッシャーのあやめのサンプルデータを使います。

合計3種類、150個を学習用とテスト用にわけました。

df = pd.read_csv('iris_train.csv')
clf = tree.DecisionTreeClassifier()

#Assign an explanatory variable
data_array = ["SepalLength", "SepalWidth", "PetalLength", "PetalWidth"]

#Assign an objective variable
class_array = ["Name"]

clf = clf.fit(df[data_array], df[class_array])

Fittingがされたので、PDFに決定木を描写させます。

with open('iris.dot', 'w') as f:
  f = tree.export_graphviz(clf, out_file=f)

import os
os.unlink("iris.dot")

from sklearn.externals.six import StringIO  
import pydot

dot_data = StringIO() 
tree.export_graphviz(clf, out_file=dot_data)

graph = pydot.graph_from_dot_data(dot_data.getvalue())
graph.write_pdf("iris.pdf") 

で、得られた決定木がこちら。
f:id:tkzs:20150707074704p:plain


この分類器に従って、残り半分のあやめデータを適用してみます。
ちゃんと分類されるんですかね。

分類器をテスト

df_p = pd.read_csv('iris_test.csv')

#Apply test data
result = (df_p[data_array])

print result
['setosa' 'setosa' 'setosa' 'setosa' 'setosa' 'setosa' 'setosa' 'setosa'
 'setosa' 'setosa' 'setosa' 'setosa' 'setosa' 'setosa' 'setosa' 'setosa'
 'setosa' 'setosa' 'setosa' 'setosa' 'setosa' 'setosa' 'setosa' 'setosa'
 'setosa' 'versicolor' 'versicolor' 'versicolor' 'versicolor' 'versicolor'
 'versicolor' 'versicolor' 'versicolor' 'versicolor' 'versicolor'
 'versicolor' 'versicolor' 'versicolor' 'versicolor' 'versicolor'
 'versicolor' 'versicolor' 'versicolor' 'versicolor' 'versicolor'
 'virginica' 'versicolor' 'versicolor' 'versicolor' 'versicolor'
 'virginica' 'virginica' 'virginica' 'virginica' 'virginica' 'virginica'
 'versicolor' 'virginica' 'virginica' 'virginica' 'virginica' 'virginica'
 'virginica' 'virginica' 'virginica' 'virginica' 'virginica' 'virginica'
 'virginica' 'versicolor' 'virginica' 'virginica' 'virginica' 'virginica''virginica’] 


精度は94%(47/50)程度でした。
極々簡単にScikit-learnの動作確認でした。