An Engineer's Tsurezuregusa (Idle Jottings)

A production engineer's casual diary (mostly notes on the statistical language R)

[R] The definitive visualization package for decision tree analysis? [ggparty]

A package for visualizing decision tree results: ggparty

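Decision tree plots in R have always looked rather plain, which was a real weakness.
ggparty, a flexible package built on ggplot2, seems to have been created to fix exactly that.
See here for details: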

github.com

library(rpart)
library(partykit)
library(ggplot2)
library(ggparty)

# Prepare the data: expand the Titanic frequency table to one row per passenger
tmp <- data.frame(Titanic)
df <- data.frame(
  Class = rep(tmp$Class, tmp$Freq),
  Sex = rep(tmp$Sex, tmp$Freq),
  Age = rep(tmp$Age, tmp$Freq),
  Survived = rep(tmp$Survived, tmp$Freq)
)
head(tmp)
head(df)

# Fit a classification tree with rpart, then convert it to a party object for ggparty
ct <- rpart(Survived ~ Class + Sex + Age, data = df)
pct <- as.party(ct)
pct

Model formula:
Survived ~ Class + Sex + Age

Fitted party:
[1] root
|   [2] Sex in Male
|   |   [3] Age in Adult: No (n = 1667, err = 20.3%)
|   |   [4] Age in Child
|   |   |   [5] Class in 3rd: No (n = 48, err = 27.1%)
|   |   |   [6] Class in 1st, 2nd: Yes (n = 16, err = 0.0%)
|   [7] Sex in Female
|   |   [8] Class in 3rd: No (n = 196, err = 45.9%)
|   |   [9] Class in 1st, 2nd, Crew: Yes (n = 274, err = 7.3%)

Number of inner nodes:    4
Number of terminal nodes: 5
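As a cross-check on how to read this printout, the figures for node [3] can be recomputed by hand. The sketch below uses only base R and rebuilds the data frame so that it runs on its own. It assumes that `n` is the number of rows that fall into the node and that `err` is the share of those rows whose class differs from the node's predicted class; that reading is my assumption, though it is consistent with the numbers shown above.

```r
# Rebuild the one-row-per-passenger data frame from the built-in Titanic table
tmp <- data.frame(Titanic)
df <- tmp[rep(seq_len(nrow(tmp)), tmp$Freq), c("Class", "Sex", "Age", "Survived")]

# The expansion should yield one row per counted passenger
total_rows <- nrow(df)

# Node [3]: Sex = Male and Age = Adult, predicted class "No"
node3 <- df[df$Sex == "Male" & df$Age == "Adult", ]
n_node3 <- nrow(node3)
err_node3 <- mean(node3$Survived != "No")  # share of rows the node gets wrong

cat("rows:", total_rows, "\n")
cat("node 3: n =", n_node3, "err =", sprintf("%.1f%%", 100 * err_node3), "\n")
```

If the assumption holds, the second line should reproduce the n = 1667 and err = 20.3% printed for node [3].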

g2 <- ggparty(pct, terminal_space = 0.5)
g2 <- g2 + geom_edge(size = 1.5)
g2 <- g2 + geom_edge_label(colour = "grey", size = 6)

## Add a frequency bar chart to each terminal node
g2 <- g2 + geom_node_plot(
  shared_legend = FALSE,
  gglist = list(geom_bar(aes(x = Survived,  # I don't understand what !!ct$terms[[2]] means
                             fill = Survived)),
                theme_minimal(),
                theme(legend.position = "none"))
)

## Place a label box at each split (inner) node
g2 <- g2 + geom_node_label(
  aes(col = splitvar),
  line_list = list(aes(label = paste("Node", id)),
                   aes(label = splitvar)),
  line_gpar = list(list(
    size = 12,
    col = "black",
    fontface = "bold"
  ),
  list(size = 20)),
  ids = "inner"
)

## Add a label box with the node size to each terminal node
g2 <- g2 + geom_node_label(
  aes(label = paste0("Node ", id, ", N = ", nodesize)),
  fontface = "bold",
  ids = "terminal",
  size = 5,
  nudge_y = 0.01
)

g2

[Figure: decision tree analysis]
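One thing worth noting: the names `id`, `splitvar` and `nodesize` used inside `aes()` in the listing are not columns of `df`. As far as I can tell, they come from the per-node data frame that `ggparty()` builds and keeps in the `data` element of the returned plot object. The following is a minimal sketch for inspecting it, assuming ggparty is installed and that those three column names exist in your version; it refits the same tree so that it runs standalone.

```r
library(rpart)
library(partykit)
library(ggparty)

# Same data and tree as in the listing above
tmp <- data.frame(Titanic)
df <- tmp[rep(seq_len(nrow(tmp)), tmp$Freq), c("Class", "Sex", "Age", "Survived")]
pct <- as.party(rpart(Survived ~ Class + Sex + Age, data = df))

g <- ggparty(pct, terminal_space = 0.5)

# One row per node; these are the columns that aes() in geom_node_label() refers to
node_info <- g$data[, c("id", "splitvar", "nodesize")]
print(node_info)
```

Terminal nodes have no split variable, so `splitvar` should be `NA` for them, which is why the listing restricts that label to `ids = "inner"`.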

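The tree plot turned out well.
Another option is to show the survival proportions as a stacked bar in each terminal node: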

g <- ggparty(pct, terminal_space = 0.5)
g <- g + geom_edge(size = 1.5)
g <- g + geom_edge_label(colour = "grey", size = 6)
g <- g + geom_node_plot(
  gglist = list(
    geom_bar(aes(x = "", fill = Survived), position = "fill"),  # stacked proportions
    theme_bw(base_size = 15)
  ),
  scales = "fixed",
  ids = "terminal",
  shared_axis_labels = TRUE,
  shared_legend = TRUE,
  legend_separator = TRUE
)
g
g <- g + geom_node_label(
  aes(col = splitvar),
  line_list = list(aes(label = paste("Node", id)),
                   aes(label = splitvar)),
  line_gpar = list(list(
    size = 12,
    col = "black",
    fontface = "bold"
  ),
  list(size = 20)),
  ids = "inner"
)
g
g <- g + geom_node_label(
  aes(label = paste0("Node ", id, ", N = ", nodesize)),
  fontface = "bold",
  ids = "terminal",
  size = 5,
  nudge_y = 0.01
)
g <- g + theme(legend.position = "none")
plot(g)
ggsave(filename = "decision_tree_frequency_bar_chart.jpeg", plot = g, dpi = 800, width = 14, height = 8)

[Figure: decision tree analysis]
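On the `!!ct$terms[[2]]` expression that the comment in the first listing puzzles over, here is my current understanding, offered as a sketch rather than a definitive answer. `ct$terms` holds the model formula, its second element is the left-hand side (the response variable), and `!!` is the unquote operator from rlang that `aes()` understands. Put together, the expression would inject the response column into the mapping without hard-coding its name. The variable names `resp`, `m_unquoted` and `m_literal` are made up for this illustration.

```r
library(rpart)
library(ggplot2)
library(rlang)

tmp <- data.frame(Titanic)
df <- tmp[rep(seq_len(nrow(tmp)), tmp$Freq), c("Class", "Sex", "Age", "Survived")]
ct <- rpart(Survived ~ Class + Sex + Age, data = df)

# ct$terms is the model formula (with extra attributes); element 2 is its left-hand side
resp <- ct$terms[[2]]
class(resp)          # a symbol, not a character string
as.character(resp)

# Inside aes(), !! injects that symbol, so both mappings refer to the same column
m_unquoted <- aes(x = !!resp, fill = !!resp)
m_literal  <- aes(x = Survived, fill = Survived)
as_label(m_unquoted$x)
as_label(m_literal$x)
```

So writing `Survived` directly, as the listings here do, should give the same plot; the `!!` form only matters if the plotting code is meant to be reused with a different response variable.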

■ Reference

R ggpartyパッケージを用いた決定木の可視化 | トライフィールズ ("Visualizing decision trees with the R ggparty package", in Japanese)