首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >Softmax算法原理及实现

Softmax算法原理及实现

作者头像
guichen1013
发布2020-08-13 14:56:10
发布2020-08-13 14:56:10
1.2K0
举报
文章被收录于专栏:海边的拾遗者海边的拾遗者

在上一篇的逻辑回归中,主要是用于处理二分类问题,如果面对的是多分类问题,如手写字识别,其中有十个类别,这时候就需要对逻辑回归进行推广,且同样任意两个类之间都是线性可分的。

Softmax Regression

Softmax Regression是Logistic Regression在多分类上的推广,即类标签数量至少为2,也可以用在DNN中最后一层Layer后通过Softmax激活输出。假设有m个训练样本{(X(1),y(1)),(X(2),y(2)),…,(X(m),y(m))},其输入特征为:Rn+1'X(i),标签为{0,1,...,k}'y(i)。假设函数为每一个样本估计其所属的类别的概率P(y=j|X),具体的假设函数为:

其中q为权重向量,对于每一个样本估计其所属的类别的概率为:

同样引入类似逻辑回归中交叉熵损失函数中各类别概率的幂,即指示函数,形式如下:

最终损失函数为:

梯度下降法

这里求解选用迭代法中的梯度下降法来求解,其优点和原理在上一篇中已给出了通俗易懂的解释。qj梯度表达式为:

最后形式为:

现在可以用代码实现训练的具体过程:

代码语言:javascript
复制
def st_gd(feature_data, label_data, k, maxCycle, alpha):
    '''input: feature_data特征
            label_data标签
            k类别的个数
            maxCycle最大的迭代次数
            alpha学习率
    output: weights权重'''
    m, n = np.shape(feature_data)
    weights = np.mat(np.ones((n, k)))
    i = 0
    while i <= maxCycle:
        err = np.exp(feature_data * weights)
        if i % 500 == 0:
            print("\t-----iter: ", i , ", cost: ", cost(err, label_data))
        rowsum = -err.sum(axis=1)
        rowsum = rowsum.repeat(k, axis=1)
        err = err / rowsum
        for x in range(m):
            err[x, label_data[x, 0]] += 1
        weights = weights + (alpha / m) * feature_data.T * err      
        i += 1           
    return weights

其中计算损失函数值的函数为cost,具体实现如下:

代码语言:javascript
复制
def cost(err, label_data):
    '''input: err类别概率
              label_data标签的值
    output: sum_cost / m损失函数的值'''
    m = np.shape(err)[0]
    sum_cost = 0.0
    for i in range(m):
        if err[i, label_data[i, 0]] / np.sum(err[i, :]) > 0:
            sum_cost -= np.log(err[i, label_data[i, 0]] / np.sum(err[i, :]))
        else:
            sum_cost -= 0
    return sum_cost / m

小结

本文介绍的Softmax Regression存在参数冗余的特点,即权重向量减去一个任意向量后对预测结果没有任何影响,也就是说存在多组最优解,而之前提到的Logistic Regression则是本文模型中的k取2时的特殊情况。

到这里整个流程基本就结束了~

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2020-05-01,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 海边的拾遗者 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • Softmax Regression
  • 梯度下降法
  • 小结
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档