10分钟
模型接口
示例:
# Class-name -> integer target used when building the label column.
# Only two species are mapped; the 'Iris-setosa' entry was commented out in
# the original example, so that class has no label here.
_label_map = {
    species: target
    for target, species in enumerate(('Iris-versicolor', 'Iris-virginica'))
}
class BoosterTest:
    """Demonstrate driving an ``lgb.Booster`` by hand on a two-class iris CSV.

    ``lgb`` is presumably the ``lightgbm`` module and ``train_test_split`` the
    scikit-learn helper; both imports live outside this excerpt.

    The constructor loads the data, splits it 70/30, wraps both parts in
    ``lgb.Dataset`` objects and creates the booster. ``test()`` then prints
    the booster's feature information and runs a few boosting iterations,
    printing the AUC on both datasets after each one.
    """

    def __init__(self, data_path='./data/iris.csv'):
        """Load the data and build the train/validation sets and the booster.

        Args:
            data_path: CSV file containing the four measurement columns and a
                ``Class`` column. Defaults to the path the example always used.

        Raises:
            KeyError: if a ``Class`` value has no entry in ``_label_map``.
        """
        df = pd.read_csv(data_path)
        _feature_names = ['Sepal Length', 'Sepal Width', 'Petal Length', 'Petal Width']
        features = df[_feature_names]
        # Translate the class names into the integer targets from _label_map.
        labels = df['Class'].map(lambda class_name: _label_map[class_name])
        # Stratified, seeded split so the class balance and the result are reproducible.
        train_X, test_X, train_Y, test_Y = train_test_split(
            features, labels, test_size=0.3,
            stratify=labels, shuffle=True, random_state=1)
        print([item.shape for item in (train_X, test_X, train_Y, test_Y)])

        self._train_set = lgb.Dataset(data=train_X, label=train_Y,
                                      feature_name=_feature_names)
        # The validation set references the training set when it is built.
        self._validate_set = lgb.Dataset(data=test_X, label=test_Y,
                                         reference=self._train_set)
        self._booster = lgb.Booster(
            params={
                'boosting': 'gbdt',
                'verbosity': 1,  # print messages
                'learning_rate': 0.1,  # learning rate
                'num_leaves': 5,
                'max_depth': 5,
                'objective': 'binary',
                'metric': 'auc',
                'seed': 321,
            },
            train_set=self._train_set)
        self._booster.add_valid(self._validate_set, 'validate1')
        self._booster.set_train_data_name('trainAAAAA')

    def print_attr(self):
        """Print the booster's feature names and feature count."""
        print('feature name:', self._booster.feature_name())
        # feature name: ['Sepal_Length', 'Sepal_Width', 'Petal_Length', 'Petal_Width']
        print('feature nums:', self._booster.num_feature())
        # feature nums: 4

    def test_train(self, num_rounds=4):
        """Run boosting iterations one at a time, printing metrics after each.

        Args:
            num_rounds: number of ``update`` calls to make (default 4, the
                value the example originally hard-coded).
        """
        for _ in range(num_rounds):
            self._booster.update(self._train_set)
            print('after iter:%d' % self._booster.current_iteration())
            print('train eval:', self._booster.eval(self._train_set, name='train'))
            print('test eval:', self._booster.eval(self._validate_set, name='eval'))
        # Sample output recorded with the original example:
        # after iter:1
        # train eval: [('train', 'auc', 0.9776530612244898, True)]
        # test eval: [('eval', 'auc', 0.9783333333333334, True)]
        # after iter:2
        # train eval: [('train', 'auc', 0.9907142857142858, True)]
        # test eval: [('eval', 'auc', 0.9872222222222222, True)]
        # after iter:3
        # train eval: [('train', 'auc', 0.9922448979591837, True)]
        # test eval: [('eval', 'auc', 0.9888888888888889, True)]
        # after iter:4
        # train eval: [('train', 'auc', 0.9922448979591837, True)]
        # test eval: [('eval', 'auc', 0.9888888888888889, True)]

    def test(self):
        """Run the whole demo: print attributes, then train and evaluate."""
        self.print_attr()
        self.test_train()
学员评价