首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

GPyOpt,贝叶斯优化的无敌工具!

想写个机器学习模型,参数调起来头都大了?试试GPyOpt吧!这个基于贝叶斯优化的Python库,能帮你自动找到最优参数。它不像传统的网格搜索那样死板,而是会“思考”下一步该试哪组参数,聪明又高效。

1.

安装那些事

装这玩意可得注意了,它依赖 GPy,得先把 GPy 装好:

pip install GPy

pip install GPyOpt

要是装不上,多半是 numpy 版本的问题。建议先把 numpy 降到 1.23 以下:

pip install numpy==1.22.4

2.

简单上手

来看个简单例子,假设咱们要找一个函数的最小值:

import GPyOpt

import numpy as np

def my_func(x):

  # 这是咱要优化的函数

  return (x[:, 0] - 2) ** 2 + (x[:, 1] + 1) ** 2

# 定义参数范围

bounds = [{'name': 'x1', 'type': 'continuous', 'domain': (-5, 5)},

        {'name': 'x2', 'type': 'continuous', 'domain': (-5, 5)}]

# 创建优化器

optimizer = GPyOpt.methods.BayesianOptimization(f=my_func,

                                            domain=bounds)

# 开始优化,最多跑30次

optimizer.run_optimization(max_iter=30)

温馨提示:这代码里的max_iter可别设太小,不然找不到好结果。但也别设太大,会跑很久。

3.

高级玩法

光会基础的多没意思,来点高级的:

# 加入自定义核函数

import GPy

kernel = GPy.kern.Matern52(input_dim=2)

optimizer = GPyOpt.methods.BayesianOptimization(

  f=my_func,

  domain=bounds,

  kernel=kernel,

  acquisition_type='EI'  # 期望改进

)

acquisition_type可选的值还有 ‘MPI’(最大概率改进)和 ‘UCB’(置信区间上界)。不同场景下效果不一样,可以都试试。

4.

可视化结果

要是想看看优化过程,可以整点可视化:

optimizer.plot_acquisition()  # 看收购函数

optimizer.plot_convergence() # 看收敛过程

不过这玩意画出来的图不太好看,建议用 plotly 重新画:

import plotly.graph_objects as go

x = optimizer.X

y = optimizer.Y

fig = go.Figure(data=go.Scatter(x=range(len(y)),

                             y=y.flatten(),

                             mode='lines+markers'))

fig.show()

5.

实战技巧

写了这么多年代码,给大伙分享点经验:

参数范围别设太大,容易找不到最优解

要是优化结果不理想,试试改改初始点

遇到报错先检查数据类型,GPyOpt 对数据类型可挑剔了

CPU吃不消就把num_cores参数调小点

来个更贴近实际的例子:

def train_model(params):

  learning_rate = params[:, 0]

  # 这里放你的模型训练代码

  return validation_error

bounds = [

  {'name': 'lr', 'type': 'continuous', 'domain': (0.0001, 0.1)},

  {'name': 'batch', 'type': 'discrete', 'domain': (16, 32, 64, 128)}

]

有啥不懂的就查查文档,或者去 GitHub 上逛逛 issue,啥问题都能找到答案。贝叶斯优化就是这么好使,比人工调参靠谱多了!

记住一点:优化也要讲究策略。有时候与其较真一个全局最优解,不如找个够用的局部最优解来得实在。代码写简单点,结果够用就行,何必把自己搞那么累。

  • 发表于:
  • 原文链接https://page.om.qq.com/page/O0nl-SY7v3EwAxSpIJbr2iqg0
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券