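I'm new to machine learning and Python. I'm trying to apply a random forest to predict a binary outcome. My data has 24 predictors (1000 observations); one of them is categorical (gender) and the rest are numeric. I have transformed the large-scale features and imputed missing values. Finally, I checked correlation and collinearity and removed some features based on that (which left me with 24 features). Now, when I run the random forest, it is always perfect on the training set, while the cross-validation score is not very good. Even when I apply it to the test set, it gives very, very low recall. How can I fix this?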
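The preprocessing described above (imputing missing values, transforming large-scale features, encoding the one categorical column) can be sketched as a scikit-learn `Pipeline`, so that each step is refit inside every cross-validation fold instead of once on the full data set. This is a minimal sketch on a made-up frame: the columns `income` and `age`, the `log1p` transform, and median imputation are assumptions for illustration; only `gender` and the target name `Sold` come from the post.

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer, OneHotEncoder

# Hypothetical toy data: 'gender' is categorical, 'income' is a large-scale
# numeric column with some missing values, 'Sold' is the binary target.
rng = np.random.default_rng(0)
n = 200
df = pd.DataFrame({
    "gender": rng.choice(["F", "M"], size=n),
    "income": rng.lognormal(mean=10, sigma=1, size=n),
    "age": rng.integers(18, 80, size=n).astype(float),
})
df.loc[rng.choice(n, size=20, replace=False), "income"] = np.nan
df["Sold"] = (df["income"].fillna(0) > 25000).astype(int)

# Numeric columns: impute with the median, then log-transform.
# Categorical column: one-hot encode.
preprocess = ColumnTransformer([
    ("num", Pipeline([
        ("impute", SimpleImputer(strategy="median")),
        ("log", FunctionTransformer(np.log1p)),
    ]), ["income", "age"]),
    ("cat", OneHotEncoder(handle_unknown="ignore"), ["gender"]),
])

clf = Pipeline([
    ("prep", preprocess),
    ("rf", RandomForestClassifier(n_estimators=100, random_state=0)),
])

X = df.drop(columns="Sold")
y = df["Sold"]
# The imputer is refit on each training fold, so statistics from the
# held-out fold never leak into training.
scores = cross_val_score(clf, X, y, cv=5, scoring="recall")
print(f"mean CV recall: {scores.mean():.3f}")
```

Because a tree split depends only on the ordering of a feature's values, a monotonic transform such as `log1p` does not change which partitions of the training data a random forest can make; that step matters more for linear models.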
import numpy as np
import pandas as pd
from sklearn import metrics
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold

def classification_model(model, data, predictors, outcome):
    # Fit the model on the full data set:
    model.fit(data[predictors], data[outcome])
    # Make predictions on the training set:
    predictions = model.predict(data[predictors])
    # Print training accuracy (true labels first, then predictions)
    accuracy = metrics.accuracy_score(data[outcome], predictions)
    print("Accuracy : %s" % "{0:.3%}".format(accuracy))
    # Perform k-fold cross-validation with 5 folds
    kf = KFold(n_splits=5)
    scores = []
    for train_idx, test_idx in kf.split(data):
        # Filter training data
        train_predictors = data[predictors].iloc[train_idx, :]
        # The target we're using to train the algorithm.
        train_target = data[outcome].iloc[train_idx]
        # Training the algorithm using the predictors and target.
        model.fit(train_predictors, train_target)
        # Record the accuracy on the held-out fold
        scores.append(model.score(data[predictors].iloc[test_idx, :], data[outcome].iloc[test_idx]))
    print("Cross-Validation Score : %s" % "{0:.3%}".format(np.mean(scores)))
    # Fit the model again so that it can be referred to outside the function:
    model.fit(data[predictors], data[outcome])
outcome_var = 'Sold'
model = RandomForestClassifier(n_estimators=20)
predictor_var = train.drop('Sold', axis=1).columns.values
classification_model(model,train,predictor_var,outcome_var)
#Create a series with feature importances:
featimp = pd.Series(model.feature_importances_, index=predictor_var).sort_values(ascending=False)
print(featimp)
outcome_var = 'Sold'
model = RandomForestClassifier(n_estimators=20, max_depth=20, oob_score=True)
predictor_var = ['fet1', 'fet2', 'fet3', 'fet4']
classification_model(model, train, predictor_var, outcome_var)

Posted on 2017-05-24 05:07:00
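Random forests overfit easily. To fix this, you need to do a more rigorous parameter search to find the best parameters to use. Here is a link on how to do that (from the scikit-learn documentation).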
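One way to see the overfitting directly is to put training and cross-validated scores side by side, including recall, the metric the question reports as low. A forest of fully grown trees (the default `max_depth=None`) usually scores close to 100% on its own training data, so the held-out numbers are the ones to watch. This sketch uses `make_classification` as a synthetic stand-in for the question's 1000 × 24 data; the class imbalance (`weights=[0.8, 0.2]`) is an assumption, not something the question states.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_validate

# Synthetic stand-in: 1000 rows, 24 features, assumed imbalanced binary target.
X, y = make_classification(n_samples=1000, n_features=24, n_informative=6,
                           weights=[0.8, 0.2], random_state=0)

model = RandomForestClassifier(n_estimators=200, oob_score=True, random_state=0)

# Stratified folds keep the class ratio the same in every fold,
# which plain KFold does not guarantee.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
results = cross_validate(model, X, y, cv=cv,
                         scoring=["accuracy", "recall"],
                         return_train_score=True)

for metric in ["accuracy", "recall"]:
    print(f"{metric}: train={results['train_' + metric].mean():.3f} "
          f"cv={results['test_' + metric].mean():.3f}")

# The out-of-bag estimate is another held-out score that needs no extra split.
model.fit(X, y)
print(f"OOB accuracy: {model.oob_score_:.3f}")
```

A large gap between the `train` and `cv` columns is the overfitting signal; the OOB score should land near the cross-validated accuracy rather than the training accuracy.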
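The model is overfitting, and you need to search for the parameters that work best for it. The link provides implementations of grid search and randomized search for hyperparameter estimation. For a deeper theoretical grounding, the MIT artificial intelligence lecture is also worth watching: https://www.youtube.com/watch?v=UHBmv7qCey4&t=318s.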
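A grid search along these lines might look like the sketch below, using scikit-learn's `GridSearchCV` with recall as the selection metric. The grid values are illustrative assumptions, and the data is again a synthetic `make_classification` stand-in. `max_depth` and `min_samples_leaf` limit how closely each tree can fit the training data, and `class_weight="balanced"` upweights the minority class, which often helps recall on imbalanced targets.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, StratifiedKFold

# Synthetic stand-in for the question's data (assumed imbalanced).
X, y = make_classification(n_samples=1000, n_features=24, n_informative=6,
                           weights=[0.8, 0.2], random_state=0)

# Illustrative grid: parameters that restrict tree complexity,
# plus class weighting for the minority class.
param_grid = {
    "max_depth": [None, 8],
    "min_samples_leaf": [1, 5],
    "class_weight": [None, "balanced"],
}

search = GridSearchCV(
    RandomForestClassifier(n_estimators=50, random_state=0),
    param_grid,
    scoring="recall",  # select on the metric the question cares about
    cv=StratifiedKFold(n_splits=5, shuffle=True, random_state=0),
)
search.fit(X, y)

print(search.best_params_)
print(f"best CV recall: {search.best_score_:.3f}")
```

`RandomizedSearchCV` takes the same estimator, a dict of parameter lists or distributions, and an `n_iter` budget, and samples that many combinations instead of trying every one; this scales better when the grid gets large.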
Hope this helps!
https://stackoverflow.com/questions/44140744