Page List

Search on the blog

2015年5月1日金曜日

scikit-learn(2) SVMのパラメータチューニング

 昨日に引き続き、scikit-learnの勉強。

今回は”手書き数字”を識別するデータセット「digits」にチャレンジした。
このデータセットは「iris」とは異なり、デフォルトパラメータで識別を行なっても良い結果はえられない。

ということでパラメータのチューニングを行わなければならない。
パラメータの候補を渡すとcross validationを行って良いパラメータを選んでくれる関数があったのでそれを使用した。
from sklearn import svm, datasets, grid_search, metrics
from numpy import logspace

# load data
digits = datasets.load_digits()
X = digits.data
Y = digits.target

# split data into (training, test)
train_data_num = 1000
trainX, trainY = X[:train_data_num], Y[:train_data_num]
testX, testY = X[train_data_num:], Y[train_data_num:]

# train SVM
# use cross validation to choose good parameters from grid points
parameters = {
    'C' : logspace(-10, 10, base=2),
    'gamma' : logspace(-10, 10, base=2)
}
 
grsrch = grid_search.GridSearchCV(svm.SVC(), parameters)
grsrch.fit(trainX[:100], trainY[:100])

# get an estimator with the best parameters
clf = grsrch.best_estimator_
clf.fit(trainX, trainY)

# predict test data
predictY = clf.predict(testX)
print metrics.classification_report(testY, predictY)
print metrics.confusion_matrix(testY, predictY)
パラメータCの候補は、2^{-10}, 2^{-9}, 2^{-8}, .... 2^{10}、
パラメータgammaの候補は、2^{-10}, 2^{-9}, 2^{-8}, .... 2^{10}
とした。

まず学習データのうち100個だけを使って、パラメータチューニングを実施した。 その後最良のパラメータを持つSVMをすべての学習データで学習させた。

 結果は、以下のとおり。テストデータの識別率は97%。

             precision    recall  f1-score   support

          0       1.00      0.99      0.99        79
          1       0.99      0.96      0.97        80
          2       0.99      0.99      0.99        77
          3       0.97      0.86      0.91        79
          4       0.99      0.95      0.97        83
          5       0.94      0.99      0.96        82
          6       0.99      0.99      0.99        80
          7       0.95      0.99      0.97        80
          8       0.94      1.00      0.97        76
          9       0.94      0.98      0.96        81

avg / total       0.97      0.97      0.97       797

[[78  0  0  0  1  0  0  0  0  0]
 [ 0 77  1  0  0  0  0  0  1  1]
 [ 0  0 76  1  0  0  0  0  0  0]
 [ 0  0  0 68  0  3  0  4  4  0]
 [ 0  0  0  0 79  0  0  0  0  4]
 [ 0  0  0  0  0 81  1  0  0  0]
 [ 0  1  0  0  0  0 79  0  0  0]
 [ 0  0  0  0  0  1  0 79  0  0]
 [ 0  0  0  0  0  0  0  0 76  0]
 [ 0  0  0  1  0  1  0  0  0 79]]
done.

0 件のコメント:

コメントを投稿