【R言語】caretでの独自評価指標(マシューズ相関係数MCCとF1)を利用した学習モデルの構築
caretによる機械学習モデルの評価にマシューズ係数を採用する
マシューズ相関係数とは機械学習の2値分類問題で、
正と負の割合が不均衡の場合に用いられる評価指標である。
製造業における不良解析でも工程中のデータは通常不均衡データであり、
この指標を採用した学習モデルをつくらないと、
不良の特徴量が反映されないのでは考える。
詳しい定義は下記が参考となった。
マシューズ相関係数とは - Matthews Correlation Coefficient -
今回はcaretでカスタム評価指標をつくり、
F1値とMCCの算出を試みた。
Rにおける実装の際に参考にしたのは以下の本である。
データ分析プロセス / 金 明哲 編 福島 真太朗 著 | 共立出版
library(mlbench) library(caret) library(tidyverse) library(doParallel) library(pROC) data(Sonar) dim(Sonar) str(Sonar) inTrain <- createDataPartition(Sonar$Class,p=0.8,list=FALSE) nrow(inTrain)/nrow(Sonar) Sonar_train <- Sonar[inTrain,] Sonar_test <- Sonar[-inTrain,] ###https://shumagit.github.io/myblog/2018/09/15/learn-caret/ ## グリッドサーチ用のdata.frame par.grid <- expand.grid( mtry = c(3:12), splitrule = c("gini"), #"gini" "variance"だとエラー, min.node.size = c(3,4,5,6), stringsAsFactors = FALSE ) print(par.grid) #並列処理設定 detectCores() cl <- makePSOCKcluster(detectCores()) registerDoParallel(cl) # PrecisionやRecallなどを評価指標とする評価関数 my.summary <- function(data, lev = NULL, model = NULL) { if (is.character(data$obs)) { data$obs <- factor(data$obs, levels = lev) } conf <- table(data$pred, data$obs) # 混合行列 prec <- conf[1, 1]/sum(conf[1, ]) # Precsion conf[1, 1]=TP conf[1, 2]=FP conf[2, 1]=FN rec <- conf[1, 1]/sum(conf[, 1]) # Recall f.value <- 2 * prec * rec/(prec + rec) # F値 acc <- sum(diag(conf))/sum(conf) # Accuracy TP <- conf[1, 1] FP <- conf[1, 2] FN <- conf[2, 1] TN <- conf[2, 2] Mcc <-((TP*TN)-(FP*FN))/sqrt((TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)) out <- c(Precision = prec, Recall = rec, F = f.value, Accuracy = acc , MCC = Mcc) #マシューズ相関係数Mcc out } tc3 <- trainControl( method = "cv", number = 5, classProbs = T, selectionFunction = "best", summaryFunction = my.summary ) ## 繰り返し交差検証でモデル評価 tc3_1 <- trainControl( method = "repeatedcv", number = 5, repeats =10, classProbs = T, selectionFunction = "best", summaryFunction = my.summary ) fit_ranger3 <- caret::train(Class~., data = Sonar_train, method = "ranger", tuneGrid = par.grid, trControl = tc3_1, num.trees = 500, importance = "permutation", metric = "MCC" ) ecuteParallelProcess() print(fit_ranger3) imp <- varImp(fit_ranger3) df_imp <- imp$importance %>% rownames_to_column() %>% arrange(desc(Overall))
- 統計解析言語SのクローンとしてGNUのもとで開発が進められている、統計解析・可視化のための言語・環境。 GPLライセンスのオープンソースソフトウェアである。 フリーソフトながらも非常に高性能で、標準的な統計手法はほぼすべて簡単なコマンドで実行できる.. 続きを読む
- R: The R Project for Statistical Computing www.r-project.org
- このキーワードを含むブログを見る