こんにちは、えんせきです。
最近は宇宙よりも遠い場所を見ては泣いてます。笑えるシーンとシリアスなシーンのバランスが良いと涙腺に色々と働きかけてきますよね。
つまりなにしたの?
せっかく導入したXGBoostがちゃんと使えるのか試すために、機械学習のHello Worldとも言えるIrisデータ(アヤメの花弁とかのデータ)を使ってアヤメの種類がどれだけ当てられるのか試してみた。特徴量の寄与度合いや木の可視化もしてみる。
前回まで
先日データセットの準備を行って学習の準備が完了した。
ensekitt.hatenablog.com
方針
初回
次回
- ハイパーパラメータ探索しつつ学習する
- 評価する
- 変数の重要度を可視化
- 決定木をプロット
ハイパーパラメータ探索しつつ学習する
# 今回はアヤメの種類を当てるクラス分類なのでXGBClassifier clf = xgb.XGBClassifier() # ハイパーパラメータ探索 clf_cv = model_selection.GridSearchCV(clf, {'max_depth': [2,4,6], 'n_estimators': [50,100,200]}, verbose=1) clf_cv.fit(train_df_x, [i[0] for i in train_df_y.values]) print(clf_cv.best_params_, clf_cv.best_score_) # 良さげなパラメータで学習しなおす clf = xgb.XGBClassifier(**clf_cv.best_params_) clf.fit(train_df_x, [i[0] for i in train_df_y.values])
今回はmax_depthは4、n_estimatorsは50が良いとされた様子
評価する
pred = clf.predict(test_df_x) print(confusion_matrix([i[0] for i in test_df_y.values], pred)) print(classification_report([i[0] for i in test_df_y.values], pred))
# 混同行列 [[19 0 0] [ 0 11 1] [ 0 2 12]] # Precision Recall と F値 precision recall f1-score support 0 1.00 1.00 1.00 19 1 0.85 0.92 0.88 12 2 0.92 0.86 0.89 14 avg / total 0.94 0.93 0.93 45
F値で0.93となかなか良さげな値が出てますね。シードを固定せずに何度か走らせると0.8~0.9前後をうろうろしてた。
変数の重要度を可視化
xgb.plot_importance(clf) plt.show()
花びら(Petal)の長さと幅が重要みたい。
決定木のプロット(1例)
xgb.to_graphviz(clf, num_trees=1)
花びらに関して木がつくられていることがわかる。
ツリーの番号を変えればいろんな木によって構成されていることがわかる。
これで一旦XGBoostによるクラス分類ができた。
Scikit-learnに慣れていればあまり抵抗なく使えるみたいで良い。