EnsekiTT Blog

EnsekiTTが書くブログです。

IrisのデータをXGBoostで解析してみる話(次回)

こんにちは、えんせきです。
最近は宇宙よりも遠い場所を見ては泣いてます。笑えるシーンとシリアスなシーンのバランスが良いと涙腺に色々と働きかけてきますよね。

つまりなにしたの?

せっかく導入したXGBoostがちゃんと使えるのか試すために、機械学習Hello Worldとも言えるIrisデータ(アヤメの花弁とかのデータ)を使ってアヤメの種類がどれだけ当てられるのか試してみた。特徴量の寄与度合いや木の可視化もしてみる。
f:id:ensekitt:20180218040314j:plain

前回まで

先日データセットの準備を行って学習の準備が完了した。
ensekitt.hatenablog.com

方針

初回

次回

  • ハイパーパラメータ探索しつつ学習する
  • 評価する
  • 変数の重要度を可視化
  • 決定木をプロット

ハイパーパラメータ探索しつつ学習する

# 今回はアヤメの種類を当てるクラス分類なのでXGBClassifier
clf = xgb.XGBClassifier()

# ハイパーパラメータ探索
clf_cv = model_selection.GridSearchCV(clf, {'max_depth': [2,4,6], 'n_estimators': [50,100,200]}, verbose=1)
clf_cv.fit(train_df_x, [i[0] for i in train_df_y.values])
print(clf_cv.best_params_, clf_cv.best_score_)

# 良さげなパラメータで学習しなおす
clf = xgb.XGBClassifier(**clf_cv.best_params_)
clf.fit(train_df_x, [i[0] for i in train_df_y.values])

f:id:ensekitt:20180218035437p:plain
今回はmax_depthは4、n_estimatorsは50が良いとされた様子

評価する

pred = clf.predict(test_df_x)
print(confusion_matrix([i[0] for i in test_df_y.values], pred))
print(classification_report([i[0] for i in test_df_y.values], pred))
#  混同行列
[[19  0  0]
 [ 0 11  1]
 [ 0  2 12]]
# Precision Recall と F値
             precision    recall  f1-score   support

          0       1.00      1.00      1.00        19
          1       0.85      0.92      0.88        12
          2       0.92      0.86      0.89        14

avg / total       0.94      0.93      0.93        45

F値で0.93となかなか良さげな値が出てますね。シードを固定せずに何度か走らせると0.8~0.9前後をうろうろしてた。

変数の重要度を可視化

xgb.plot_importance(clf)
plt.show()

f:id:ensekitt:20180218035905p:plain
花びら(Petal)の長さと幅が重要みたい。

決定木のプロット(1例)

xgb.to_graphviz(clf, num_trees=1)

f:id:ensekitt:20180218035945p:plain
花びらに関して木がつくられていることがわかる。
ツリーの番号を変えればいろんな木によって構成されていることがわかる。

これで一旦XGBoostによるクラス分類ができた。
Scikit-learnに慣れていればあまり抵抗なく使えるみたいで良い。

クリエイティブ・コモンズ・ライセンス
この 作品 は クリエイティブ・コモンズ 表示 4.0 国際 ライセンスの下に提供されています。