Hisashiburi da na, everyone! (Long time no see!) ~😭
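I've finally rotated out of the trauma department, and my six months there happened to coincide with the NSFC and provincial natural science foundation grant submissions. 🙃
I was so busy that I had no time at all for my own projects. 🫠
Now that the rotation is over, I can finally write about things I find interesting. 😜
Do you have any fun packages to recommend? Please share!~🤒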
Picking up the machine learning series where we left off: today it's CatBoost. 🙊
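CatBoost, XGBoost and LightGBM are often called the three mainstream GBDT workhorses. All three are improved implementations built on the gradient boosting decision tree (GBDT) framework.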
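To make "the GBDT framework" concrete, here is a minimal base-R sketch of gradient boosting for binary log-loss. It uses depth-1 trees (stumps). Each round fits a stump to the negative gradient `y - p` and adds a shrunken copy of it to the log-odds.

This is a simplification under stated assumptions. The leaf values are plain mean residuals (a first-order step), whereas the three libraries typically use second-order leaf estimates, deeper trees and regularisation. All helper names (`fit_stump`, `gbdt_logloss`, ...) are hypothetical.

```r
# Minimal gradient boosting for binary log-loss with depth-1 trees (stumps).
# Hypothetical helper names; illustration only, not CatBoost's actual algorithm.
sigmoid <- function(z) 1 / (1 + exp(-z))

# Fit one stump to residuals r by minimising squared error over thresholds on x
fit_stump <- function(x, r) {
  best <- list(sse = Inf)
  for (s in sort(unique(x))[-1]) {     # skip the minimum so both sides are non-empty
    left <- x < s
    pl <- mean(r[left])
    pr <- mean(r[!left])
    sse <- sum((r[left] - pl)^2) + sum((r[!left] - pr)^2)
    if (sse < best$sse) best <- list(sse = sse, s = s, pl = pl, pr = pr)
  }
  best
}

predict_stump <- function(st, x) ifelse(x < st$s, st$pl, st$pr)

gbdt_logloss <- function(x, y, n_iter = 50, eta = 0.1) {
  f0 <- log(mean(y) / (1 - mean(y)))   # start from the overall log-odds
  f <- rep(f0, length(y))
  trees <- vector("list", n_iter)
  for (m in seq_len(n_iter)) {
    r <- y - sigmoid(f)                # negative gradient of log-loss w.r.t. the log-odds f
    trees[[m]] <- fit_stump(x, r)
    f <- f + eta * predict_stump(trees[[m]], x)  # shrunken additive update
  }
  list(f0 = f0, trees = trees, eta = eta)
}

predict_gbdt <- function(model, x) {
  f <- rep(model$f0, length(x))
  for (st in model$trees) f <- f + model$eta * predict_stump(st, x)
  sigmoid(f)                           # convert log-odds back to probabilities
}

# Usage on simulated data: the event probability jumps at x = 0.5
set.seed(1)
x <- runif(200)
y <- rbinom(200, 1, ifelse(x > 0.5, 0.9, 0.1))
model <- gbdt_logloss(x, y)
p <- predict_gbdt(model, x)
logloss <- function(y, p) -mean(y * log(p) + (1 - y) * log(1 - p))
ll <- logloss(y, p)
ll0 <- logloss(y, rep(mean(y), length(y)))  # constant-probability baseline
```

Boosting should leave `ll` below the constant-model loss `ll0`. The predicted probabilities should also be higher on the `x > 0.5` side.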
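CatBoost is a GBDT framework whose base learners are oblivious (symmetric) decision trees, in which every node at the same depth uses the same split. It needs relatively little parameter tuning, supports categorical variables natively, and delivers high accuracy. The main pain point it addresses is handling categorical features efficiently and sensibly. 🥳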
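CatBoost's answer to the categorical-feature problem is "ordered target statistics". Rows are visited in some order. Each row's category is replaced by a smoothed average of the targets of rows that came earlier and share the same category, so a row's own label never leaks into its encoding.

Below is a simplified sketch for a single ordering. The prior value, the prior weight `a` and the helper name are illustrative assumptions. The real library also uses several random permutations and combinations of categorical features.

```r
# Ordered target statistics for one categorical column (simplified sketch).
# `prior` and `a` (prior weight) are illustrative choices, not CatBoost defaults.
ordered_target_stat <- function(category, y, prior = mean(y), a = 1, perm = seq_along(y)) {
  category <- as.character(category)
  sums <- numeric(0)    # running sum of targets per category (named by level)
  counts <- numeric(0)  # running count per category
  enc <- numeric(length(y))
  for (i in perm) {     # visit rows in the given order (normally a random permutation)
    k <- category[i]
    s <- if (k %in% names(sums)) sums[[k]] else 0
    n <- if (k %in% names(counts)) counts[[k]] else 0
    enc[i] <- (s + a * prior) / (n + a)  # only rows seen *before* row i contribute
    sums[k] <- s + y[i]                  # now add row i's own target to the history
    counts[k] <- n + 1
  }
  enc
}

category <- c("a", "b", "a", "a", "b")
y <- c(1, 0, 0, 1, 1)
enc <- ordered_target_stat(category, y, prior = 0.5, a = 1)
print(enc)
# [1] 0.50 0.50 0.75 0.50 0.25
```

The first "a" row has no history, so it gets the prior 0.5. The third row sees one earlier "a" with target 1 and gets (1 + 0.5) / (1 + 1) = 0.75. In practice you would pass `perm = sample(length(y))` rather than the identity order.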
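rm(list = ls())
library(tidyverse)
library(catboost)
library(survival)
# lung$status is coded 1 = censored, 2 = dead; Logloss needs a 0/1 label
dat <- lung %>%
  dplyr::filter(!is.na(ph.ecog)) %>%            # drop the one patient with a missing ECOG score
  dplyr::mutate(status = status - 1,             # 0 = censored, 1 = dead
                sex = factor(sex),               # factor columns of a data.frame are
                ph.ecog = factor(ph.ecog)) %>%   # treated as categorical features by catboost
  dplyr::select(status, sex, ph.ecog, everything())
# note: `time` (follow-up days) stays in as a predictor; it is closely tied to status,
# so the performance metrics below are likely optimistic
DT::datatable(dat)
set.seed(123)                                    # reproducible split
train_indices <- sample(seq_len(nrow(dat)), size = floor(0.7 * nrow(dat)))
test_indices <- setdiff(seq_len(nrow(dat)), train_indices)  # every remaining row goes to the test set
train_data <- dat[train_indices, ]
test_data <- dat[test_indices, ]
trainpool <- catboost.load_pool(data = train_data[, -1], label = train_data$status)
testpool <- catboost.load_pool(data = test_data[, -1], label = test_data$status)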
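The split above is purely random. With an imbalanced outcome, a stratified split keeps the event rate the same in both sets. The base-R sketch below is an optional swap-in, not what the code above does.

It samples the same fraction within each outcome class. `stratified_split` is a hypothetical helper, and per-class sizes are rounded to the nearest integer.

```r
# Stratified train/test split: sample the same fraction within each outcome class.
# `stratified_split` is a hypothetical helper, not part of catboost or dplyr.
stratified_split <- function(y, p = 0.7, seed = 123) {
  set.seed(seed)
  by_class <- split(seq_along(y), y)   # row indices grouped by outcome value
  train <- unlist(lapply(by_class, function(idx) {
    # index idx directly: safe even when a class has a single row
    idx[sample.int(length(idx), round(p * length(idx)))]
  }), use.names = FALSE)
  train <- sort(train)
  list(train = train, test = setdiff(seq_along(y), train))
}

y <- rep(c(0, 1), times = c(30, 70))
idx <- stratified_split(y)
length(idx$train)      # 70
mean(y[idx$train])     # 0.7, same event rate as the full data
```

On the lung data this would be called as `stratified_split(dat$status)`, with `dat[idx$train, ]` and `dat[idx$test, ]` replacing the random split.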
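params <- list(iterations = 1000,
               loss_function = 'Logloss',
               random_seed = 123,
               learning_rate = 0.01,
               verbose = 0,
               use_best_model = TRUE,   # keep the iteration with the best score on the eval pool
               od_type = 'Iter',        # overfitting detector: stop after od_wait
               od_wait = 10)            # iterations without improvement
# testpool doubles as the eval set for early stopping, so metrics computed on it are somewhat optimistic
cat_model <- catboost.train(trainpool, testpool, params)
cat_model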
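With `od_type = 'Iter'` and `od_wait = 10`, training stops once the eval metric has gone 10 iterations without improving. `use_best_model = TRUE` then keeps the model truncated at the best iteration.

The base-R sketch below models that rule on a made-up loss curve. It is a simplified reading, and CatBoost's internal bookkeeping may differ by an iteration. `early_stop` is a hypothetical helper.

```r
# Simplified model of iteration-based early stopping (od_type = 'Iter').
# `early_stop` is a hypothetical helper; CatBoost's internal bookkeeping may differ slightly.
early_stop <- function(eval_loss, od_wait = 10) {
  best_iter <- 1
  for (i in seq_along(eval_loss)) {
    if (eval_loss[i] < eval_loss[best_iter]) best_iter <- i   # new best eval score
    if (i - best_iter >= od_wait) {                           # od_wait rounds without improvement
      return(list(best_iter = best_iter, stopped_at = i))
    }
  }
  list(best_iter = best_iter, stopped_at = length(eval_loss))
}

# made-up eval loss: improves for 4 iterations, then plateaus slightly worse
eval_loss <- c(0.69, 0.60, 0.55, 0.54, rep(0.56, 20))
res <- early_stop(eval_loss, od_wait = 10)
res$best_iter    # 4: with use_best_model = TRUE, the model is truncated here
res$stopped_at   # 14: training halts 10 iterations after the best one
```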
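# predicted probability of the positive class (death)
pred <- catboost.predict(cat_model,
                         testpool,
                         prediction_type = "Probability")
# patients whose predicted probability is above 0.7 are classified as deaths
ModelMetrics::confusionMatrix(test_data$status, pred, cutoff = 0.7)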
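The cutoff turns probabilities into class labels. Raising it makes the model call fewer deaths, which tends to raise specificity at the cost of sensitivity.

Here is a base-R sketch of the same bookkeeping on a toy vector, with sensitivity and specificity added. The helper name is hypothetical. Using a strict `>` at the cutoff is this sketch's choice, and ModelMetrics may treat exact ties differently.

```r
# Confusion matrix and sensitivity/specificity at a probability cutoff.
# Hypothetical helper; this sketch labels a row positive when prob > cutoff.
confusion_at_cutoff <- function(actual, prob, cutoff = 0.5) {
  predicted <- as.integer(prob > cutoff)
  tab <- table(predicted = factor(predicted, levels = 0:1),
               actual = factor(actual, levels = 0:1))
  list(table = tab,
       sensitivity = tab["1", "1"] / sum(tab[, "1"]),  # true positives / all actual positives
       specificity = tab["0", "0"] / sum(tab[, "0"]))  # true negatives / all actual negatives
}

actual <- c(1, 1, 1, 0, 0, 1, 0, 1)
prob <- c(0.90, 0.80, 0.60, 0.65, 0.20, 0.75, 0.40, 0.30)
res <- confusion_at_cutoff(actual, prob, cutoff = 0.7)
res$table
res$sensitivity   # 0.6: 3 of the 5 deaths exceed 0.7
res$specificity   # 1: none of the 3 survivors exceed 0.7
```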
library(pROC)
cat_roc <- roc(test_data$status, pred,
               auc = TRUE,
               ci = TRUE,        # confidence interval for the AUC
               smooth = TRUE)    # binormal smoothing: the curve and AUC refer to the smoothed ROC
ggroc(cat_roc, legacy.axes = TRUE) +
  geom_segment(aes(x = 0, xend = 1, y = 0, yend = 1), color = "darkgrey", linetype = 4) +
  theme_bw() +
  ggtitle('ROC') +
  ggsci::scale_color_npg() +
  annotate("text", x = 0.75, y = 0.125, label = paste("AUC = ", round(as.numeric(cat_roc$auc), 3)))
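The empirical AUC is the probability that a randomly chosen death gets a higher predicted probability than a randomly chosen survivor, with ties counted as 1/2. It can be computed directly from ranks (the Mann-Whitney formula).

Note that with `smooth = TRUE` pROC reports the AUC of the smoothed curve, so the number may differ from this empirical value. `auc_rank` is a hypothetical helper.

```r
# Empirical AUC via the Mann-Whitney rank formula (hypothetical helper).
auc_rank <- function(actual, prob) {
  r <- rank(prob)                        # average ranks, so ties count as 1/2
  n1 <- sum(actual == 1)
  n0 <- sum(actual == 0)
  (sum(r[actual == 1]) - n1 * (n1 + 1) / 2) / (n1 * n0)
}

actual <- c(0, 0, 1, 1)
prob <- c(0.10, 0.40, 0.35, 0.80)
auc_rank(actual, prob)   # 0.75: 3 of the 4 positive-negative pairs are ordered correctly
```

On the real predictions, `auc_rank(test_data$status, pred)` should match what `roc(..., smooth = FALSE)` reports.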
library(shapviz)
# recent shapviz versions already provide an equivalent method for CatBoost models;
# defining it here mainly matters for older installations
shapviz.catboost.Model <- function(object, X_pred, X = X_pred, collapse = NULL, ...) {
  if (!requireNamespace("catboost", quietly = TRUE)) {
    stop("Package 'catboost' not installed")
  }
  stopifnot(
    "X must be a matrix or data.frame. It can't be an object of class catboost.Pool" =
      is.matrix(X) || is.data.frame(X),
    "X_pred must be a matrix, a data.frame, or a catboost.Pool" =
      is.matrix(X_pred) || is.data.frame(X_pred) || inherits(X_pred, "catboost.Pool"),
    "X_pred must have column names" = !is.null(colnames(X_pred))
  )
  if (!inherits(X_pred, "catboost.Pool")) {
    X_pred <- catboost::catboost.load_pool(X_pred)
  }
  # one column of SHAP values per feature, plus a last column holding the expected value
  S <- catboost::catboost.get_feature_importance(object, X_pred, type = "ShapValues", ...)
  pp <- ncol(X_pred) + 1L
  baseline <- S[1L, pp]              # expected (baseline) prediction on the raw scale
  S <- S[, -pp, drop = FALSE]
  colnames(S) <- colnames(X_pred)
  shapviz(S, X = X, baseline = baseline, collapse = collapse)
}
shp <- shapviz(cat_model, X_pred = test_data[, -1])
shp
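What shapviz plots is the SHAP matrix `S` together with the baseline taken from CatBoost's last column. For a Logloss model both live on the raw log-odds scale, and they are additive: the baseline plus the row sum of `S` reproduces each row's raw prediction.

The base-R sketch below checks that property on a simulated linear log-odds model. Under feature independence, its exact SHAP values have the closed form `b_j * (x_j - mean(x_j))`. `linear_shap` and the simulated data are hypothetical and unrelated to the lung model.

```r
# Exact SHAP values for a linear log-odds model f(x) = b0 + sum_j b_j * x_j,
# assuming independent features (hypothetical helper, illustration only).
linear_shap <- function(X, b0, b) {
  X <- as.matrix(X)
  mu <- colMeans(X)
  S <- sweep(sweep(X, 2, mu), 2, b, "*")  # phi_ij = b_j * (x_ij - mu_j)
  baseline <- b0 + sum(b * mu)            # the average prediction over X
  list(S = S, baseline = baseline)
}

set.seed(1)
X <- cbind(x1 = rnorm(5), x2 = rnorm(5))   # simulated features
b0 <- -1
b <- c(0.8, -0.5)
sh <- linear_shap(X, b0, b)
f <- b0 + drop(X %*% b)                    # model output on the log-odds scale
# additivity: baseline + row sum of SHAP values reproduces each prediction
all.equal(unname(sh$baseline + rowSums(sh$S)), unname(f))
```

The same identity should hold for the CatBoost model if you compare against `catboost.predict(cat_model, testpool)`, whose default `prediction_type` is `"RawFormulaVal"`. Treat that as an expectation to check rather than a verified result.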
I've covered these visualization methods in earlier posts, so if anything here is unclear, have a look back at those. 🥳
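sv_waterfall(shp, row_id = 1)            # how each feature moves patient 1 away from the baseline
sv_force(shp, row_id = 1)                # the same decomposition as a force plot
sv_importance(shp, kind = "beeswarm")    # distribution of SHAP values per feature
sv_importance(shp, fill = "#F2613F")     # bar chart of mean absolute SHAP values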
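The bar version of `sv_importance()` ranks features by their mean absolute SHAP value. A base-R sketch of that computation on a small made-up SHAP matrix is shown below.

`shap_importance` and the numbers are hypothetical. On the real object, shapviz's accessor `get_shap_values(shp)` should return the matrix to feed in.

```r
# Mean absolute SHAP value per feature, sorted from most to least important.
# Hypothetical helper and made-up SHAP values, for illustration only.
shap_importance <- function(S) sort(colMeans(abs(S)), decreasing = TRUE)

S <- matrix(c(0.2, -0.4, 0.1,      # SHAP values of "sex" for 3 patients
              -0.1, 0.05, 0.3),    # SHAP values of "ph.ecog" for the same patients
            ncol = 2,
            dimnames = list(NULL, c("sex", "ph.ecog")))
imp <- shap_importance(S)
imp
```

Signs cancel in a plain mean, which is why the absolute value is taken. A feature that pushes some patients strongly up and others strongly down still counts as important.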
sv_dependence(shp,
              v = "ph.ecog",
              alpha = 0.5,
              size = 1.5,
              color_var = NULL)          # SHAP value of ph.ecog against its value, no colouring
sv_dependence(shp,
              v = c("sex",
                    "age",
                    "ph.ecog",
                    "ph.karno"))         # several dependence plots at once
Finally, here's hoping we all escape the rat race soon!~