【R】The definitive package for visualizing decision trees?【ggparty】
ggparty: a package for visualizing decision tree results
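Decision tree plots in R have always looked rather bare, which was a real weakness.
To fix that, a package called ggparty has apparently been released. It builds on ggplot and gives you much more freedom over the layout.
See the ggparty documentation for details.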
library(rpart)
library(partykit)
library(ggplot2)
library(ggparty)

# Prepare the data: expand the Titanic frequency table to one row per passenger
tmp <- data.frame(Titanic)
df <- data.frame(
  Class    = rep(tmp$Class, tmp$Freq),
  Sex      = rep(tmp$Sex, tmp$Freq),
  Age      = rep(tmp$Age, tmp$Freq),
  Survived = rep(tmp$Survived, tmp$Freq)
)
head(tmp)
head(df)

# Fit a classification tree with rpart and convert it to a party object
ct <- rpart(Survived ~ Class + Sex + Age, data = df)
pct <- as.party(ct)
pct
Model formula:
Survived ~ Class + Sex + Age

Fitted party:
[1] root
|   [2] Sex in Male
|   |   [3] Age in Adult: No (n = 1667, err = 20.3%)
|   |   [4] Age in Child
|   |   |   [5] Class in 3rd: No (n = 48, err = 27.1%)
|   |   |   [6] Class in 1st, 2nd: Yes (n = 16, err = 0.0%)
|   [7] Sex in Female
|   |   [8] Class in 3rd: No (n = 196, err = 45.9%)
|   |   [9] Class in 1st, 2nd, Crew: Yes (n = 274, err = 7.3%)

Number of inner nodes:    4
Number of terminal nodes: 5
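To make the printed tree easier to read: `n` is the number of passengers that fall into a node, and `err` is the share of them that do not belong to the predicted class. The sketch below recomputes node 3 (adult males, predicted "No") directly from the Titanic table as a cross-check. The variable names `adult_male`, `n_node`, `n_survived` and `err` are my own and not part of any package.

```r
# Cross-check terminal node [3]: Sex = Male, Age = Adult, predicted class "No"
tmp <- data.frame(Titanic)

# Rows of the frequency table that belong to this node
adult_male <- subset(tmp, Sex == "Male" & Age == "Adult")

# Node size: total number of adult males
n_node <- sum(adult_male$Freq)

# The node predicts "No", so every survivor counts as a misclassification
n_survived <- sum(adult_male$Freq[adult_male$Survived == "Yes"])

err <- 100 * n_survived / n_node
cat(sprintf("n = %d, err = %.1f%%\n", n_node, err))
# n = 1667, err = 20.3%
```

The result matches the `n = 1667, err = 20.3%` shown for node 3 in the output above.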
g2 <- ggparty(pct, terminal_space = 0.5)
g2 <- g2 + geom_edge(size = 1.5)
g2 <- g2 + geom_edge_label(colour = "grey", size = 6)

## Add a frequency bar chart to each terminal node
g2 <- g2 + geom_node_plot(
  shared_legend = FALSE,
  gglist = list(
    # (I don't understand what !!ct$terms[[2]] means, so Survived is written out here)
    geom_bar(aes(x = Survived, fill = Survived)),
    theme_minimal(),
    theme(legend.position = "none")
  )
)

## Place a label box at each split (inner node)
g2 <- g2 + geom_node_label(
  aes(col = splitvar),
  line_list = list(
    aes(label = paste("Node", id)),
    aes(label = splitvar)
  ),
  line_gpar = list(
    list(size = 12, col = "black", fontface = "bold"),
    list(size = 20)
  ),
  ids = "inner"
)

## Add a label box with the node size to each terminal node
g2 <- g2 + geom_node_label(
  aes(label = paste0("Node ", id, ", N = ", nodesize)),
  fontface = "bold",
  ids = "terminal",
  size = 5,
  nudge_y = 0.01
)
g2
This looks good.
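About the `!!ct$terms[[2]]` that puzzled me in the comment above: as far as I can tell, `ct$terms` holds the model formula, its second element is the left-hand side (the response variable), and `!!` is the rlang unquote operator, which `aes()` supports. So the expression should simply insert `Survived` into the mapping without hard-coding the name. A small sketch to check this, where `response` and `mapping` are names I made up for illustration:

```r
library(rpart)
library(ggplot2)

# Same data preparation and model as above
tmp <- data.frame(Titanic)
df <- data.frame(
  Class    = rep(tmp$Class, tmp$Freq),
  Sex      = rep(tmp$Sex, tmp$Freq),
  Age      = rep(tmp$Age, tmp$Freq),
  Survived = rep(tmp$Survived, tmp$Freq)
)
ct <- rpart(Survived ~ Class + Sex + Age, data = df)

# The second element of the terms (formula) object is the response variable
response <- ct$terms[[2]]
print(response)         # an R symbol naming the response column
print(class(response))  # "name"

# !! unquotes the symbol inside aes(), so this mapping refers to the
# Survived column just like aes(x = Survived, fill = Survived)
mapping <- aes(x = !!response, fill = !!response)
print(mapping)
```

If that reading is right, writing `aes(x = !!ct$terms[[2]], fill = !!ct$terms[[2]])` would let the same plotting code work for a tree fitted with a different response variable.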
Here is another variation:
g <- ggparty(pct, terminal_space = 0.5)
g <- g + geom_edge(size = 1.5)
g <- g + geom_edge_label(colour = "grey", size = 6)

# Stacked bar per terminal node; position = "fill" shows proportions instead of counts
g <- g + geom_node_plot(
  gglist = list(
    geom_bar(aes(x = "", fill = Survived), position = "fill"),
    theme_bw(base_size = 15)
  ),
  scales = "fixed",
  ids = "terminal",
  shared_axis_labels = TRUE,
  shared_legend = TRUE,
  legend_separator = TRUE
)
g

# Label box at each split (inner node)
g <- g + geom_node_label(
  aes(col = splitvar),
  line_list = list(
    aes(label = paste("Node", id)),
    aes(label = splitvar)
  ),
  line_gpar = list(
    list(size = 12, col = "black", fontface = "bold"),
    list(size = 20)
  ),
  ids = "inner"
)
g

# Label box with the node size at each terminal node
g <- g + geom_node_label(
  aes(label = paste0("Node ", id, ", N = ", nodesize)),
  fontface = "bold",
  ids = "terminal",
  size = 5,
  nudge_y = 0.01
)
g <- g + theme(legend.position = "none")
plot(g)

# Save the plot to a file
ggsave(filename = "決定木_度数棒グラフ.jpeg", plot = g, dpi = 800, width = 14, height = 8)
■ References