首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >扩展xgboost.XGBClassifier

扩展xgboost.XGBClassifier
EN

Stack Overflow用户
提问于 2017-08-30 09:42:59
回答 1查看 1.1K关注 0票数 1

我正在尝试定义一个名为XGBExtended的类,它扩展了类xgboost.XGBClassifier,xgboost的scikit-learn API。我遇到了一些关于get_params方法的问题。下面是一个说明该问题的IPython会话。基本上,get_params似乎只返回我在XGBExtended.__init__中定义的属性,在父初始化方法(xgboost.XGBClassifier.__init__)期间定义的属性被忽略。我正在使用IPython并运行Python2.7。完整的系统规格在底部。

代码语言:javascript
运行
复制
In [182]: import xgboost as xgb
     ...: 
     ...: class XGBExtended(xgb.XGBClassifier):
     ...:   def __init__(self, foo):
     ...:     super(XGBExtended, self).__init__()
     ...:     self.foo = foo
     ...: 
     ...: clf = XGBExtended(foo = 1)
     ...: 
     ...: clf.get_params()
     ...: 
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-182-431c4c3f334b> in <module>()
      8 clf = XGBExtended(foo = 1)
      9 
---> 10 clf.get_params()

/Users/andrewhannigan/lib/xgboost/python-package/xgboost/sklearn.pyc in get_params(self, deep)
    188         if isinstance(self.kwargs, dict):  # if kwargs is a dict, update params accordingly
    189             params.update(self.kwargs)
--> 190         if params['missing'] is np.nan:
    191             params['missing'] = None  # sklearn doesn't handle nan. see #4725
    192         if not params.get('eval_metric', True):

KeyError: 'missing'

所以我遇到了一个错误,因为'missing‘不是XGBClassifier.get_params方法中的params字典中的关键字。我进入调试器来查看:

代码语言:javascript
运行
复制
In [183]: %debug
> /Users/andrewhannigan/lib/xgboost/python-package/xgboost/sklearn.py(190)get_params()
    188         if isinstance(self.kwargs, dict):  # if kwargs is a dict, update params accordingly
    189             params.update(self.kwargs)
--> 190         if params['missing'] is np.nan:
    191             params['missing'] = None  # sklearn doesn't handle nan. see #4725
    192         if not params.get('eval_metric', True):

ipdb> params
{'foo': 1}
ipdb> self.__dict__
{'n_jobs': 1, 'seed': None, 'silent': True, 'missing': nan, 'nthread': None, 'min_child_weight': 1, 'random_state': 0, 'kwargs': {}, 'objective': 'binary:logistic', 'foo': 1, 'max_depth': 3, 'reg_alpha': 0, 'colsample_bylevel': 1, 'scale_pos_weight': 1, '_Booster': None, 'learning_rate': 0.1, 'max_delta_step': 0, 'base_score': 0.5, 'n_estimators': 100, 'booster': 'gbtree', 'colsample_bytree': 1, 'subsample': 1, 'reg_lambda': 1, 'gamma': 0}
ipdb> 

如您所见,params只包含foo变量。但是,对象本身包含xgboost.XGBClassifier.__init__定义的所有参数。但是由于某些原因,从xgboost.XGBClassifier.get_params调用的BaseEstimator.get_params方法只能获得在XGBExtended.__init__方法中显式定义的参数。不幸的是,即使我用deep = True显式调用get_params,它仍然不能正常工作:

代码语言:javascript
运行
复制
ipdb> super(XGBModel, self).get_params(deep=True)
{'foo': 1}
ipdb> 

有人能说出为什么会发生这种情况吗?

系统规格:

代码语言:javascript
运行
复制
In [186]: print IPython.sys_info()
{'commit_hash': u'1149d1700',
 'commit_source': 'installation',
 'default_encoding': 'UTF-8',
 'ipython_path': '/Users/andrewhannigan/virtualenvironment/nimble_ai/lib/python2.7/site-packages/IPython',
 'ipython_version': '5.4.1',
 'os_name': 'posix',
 'platform': 'Darwin-14.5.0-x86_64-i386-64bit',
 'sys_executable': '/usr/local/Cellar/python/2.7.10/Frameworks/Python.framework/Versions/2.7/Resources/Python.app/Contents/MacOS/Python',
 'sys_platform': 'darwin',
 'sys_version': '2.7.10 (default, Jul  3 2015, 12:05:53) \n[GCC 4.2.1 Compatible Apple LLVM 6.1.0 (clang-602.0.53)]'}
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-08-30 09:55:51

这里的问题是子类的声明不正确。当您仅使用foo声明init方法时,您将覆盖原始方法。它不会自动初始化,即使基类构造函数应该有默认值。

您应该使用以下内容:

代码语言:javascript
运行
复制
class XGBExtended(xgb.XGBClassifier):
    def __init__(self, foo, max_depth=3, learning_rate=0.1,
                 n_estimators=100, silent=True,
                 objective="binary:logistic",
                 nthread=-1, gamma=0, min_child_weight=1,
                 max_delta_step=0, subsample=1, colsample_bytree=1, colsample_bylevel=1,
                 reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
                 base_score=0.5, seed=0, missing=None, **kwargs):

        # Pass the required parameters to super class
        super(XGBExtended, self).__init__(max_depth, learning_rate,
                                            n_estimators, silent, objective,
                                            nthread, gamma, min_child_weight,
                                            max_delta_step, subsample,
                                            colsample_bytree, colsample_bylevel,
                                            reg_alpha, reg_lambda,
scale_pos_weight, base_score, seed, missing, **kwargs)

        # Use other custom parameters
        self.foo = foo

在此之后,您将不会得到任何错误。

代码语言:javascript
运行
复制
clf = XGBExtended(foo = 1)
print(clf.get_params(deep=True))

>>> {'reg_alpha': 0, 'colsample_bytree': 1, 'silent': True, 
     'colsample_bylevel': 1, 'scale_pos_weight': 1, 'learning_rate': 0.1, 
     'missing': None, 'max_delta_step': 0, 'nthread': -1, 'base_score': 0.5, 
     'n_estimators': 100, 'subsample': 1, 'reg_lambda': 1, 'seed': 0, 
     'min_child_weight': 1, 'objective': 'binary:logistic', 
     'foo': 1, 'max_depth': 3, 'gamma': 0}
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/45950630

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档