
Deep Learning and Computer Vision Series (7)_Neural Network Data Preprocessing, Regularization, and Loss Functions

 

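Authors: 寒小阳 (Han Xiaoyang) && 龙心尘 (Long Xinchen)
Date: January 2016.
Source: http://blog.csdn.net/han_xiaoyang/article/details/50451460
Notice: All rights reserved. Please contact the authors before reposting and cite the source.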

1. Introduction

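In the previous section we covered the strengths and weaknesses of the various activation functions and how to choose among them, as well as how network size and regularization affect a neural network. In this section we discuss input data preprocessing, regularization, and how to set up the loss function.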

2. Setting Up the Data and the Network

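The previous section covered how to set up the components involved in the forward computation (mainly the neurons). Once the network structure and parameters are defined, we obtain the score function (readers who have forgotten it can revisit the earlier posts). Overall, a complete neural network is a chain of alternating linear mappings (inner products of weights and inputs) and nonlinear mappings (applied by the activation functions). In this section we expand on data preprocessing, weight initialization, and loss functions.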

2.1 Data Preprocessing

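When convolutional neural networks are applied to images, three common forms of preprocessing may be used, described below. We assume the data is represented as a matrix X of shape [N x D], where N is the number of samples and D is the length of the flattened vector for a single image.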

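  • Mean subtraction. This is the most common form of image preprocessing. For every feature of every training image, we subtract the mean of that feature over the entire training set; intuitively, this centers every dimension of the input data at 0. With numpy this is simply X -= np.mean(X, axis = 0). There are variants: most simply, compute a single mean over all pixels and subtract that same value from every pixel; slightly more refined, do this separately for each of the three RGB color channels.
  • Normalization. Intuitively, normalization makes all dimensions of the data vary over roughly the same range. There are two common ways to do it. One is to zero-center the data first and then divide each dimension by its standard deviation (X /= np.std(X, axis = 0)). The other is to divide each dimension by its maximum absolute value, so that all normalized values fall in [-1, 1]. More generally, normalization is worth considering on any dataset whose dimensions vary on very different scales. For images, however, this step is optional: pixel values all lie in [0, 255], so the input dimensions are already on roughly the same scale.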

The effect of these two operations on the data is illustrated below:
[Figure: mean subtraction and normalization of the data]
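As a rough illustration of the two operations above, here is a minimal NumPy sketch. It assumes a random toy matrix in place of real image data, and the helper names (`center_and_standardize`, `center_and_scale_max_abs`, `subtract_channel_mean`) are hypothetical, introduced only for this example. The last helper shows the per-RGB-channel variant of mean subtraction on an assumed `[N, H, W, 3]` image layout.

```python
import numpy as np

def center_and_standardize(X):
    """Zero-center each dimension, then divide it by its standard deviation."""
    X = X - np.mean(X, axis=0)      # mean subtraction, one mean per dimension
    return X / np.std(X, axis=0)    # each dimension now has unit standard deviation

def center_and_scale_max_abs(X):
    """Zero-center, then divide each dimension by its max absolute value -> values in [-1, 1]."""
    X = X - np.mean(X, axis=0)
    return X / np.max(np.abs(X), axis=0)

def subtract_channel_mean(images):
    """Per-channel variant: images has shape [N, H, W, 3]; subtract one mean per RGB channel."""
    return images - images.mean(axis=(0, 1, 2), keepdims=True)

rng = np.random.default_rng(0)
X = rng.uniform(0, 255, size=(500, 12))   # toy stand-in: N=500 "images", D=12 pixel values each
Xs = center_and_standardize(X)
print(np.allclose(Xs.mean(axis=0), 0), np.allclose(Xs.std(axis=0), 1))  # True True
```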

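  • PCA and whitening. This is another form of preprocessing. After zero-centering the data, we can compute its covariance matrix, which tells us the correlation structure between the dimensions of the data. Simple example code: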
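import numpy as np
# Assume the input data matrix X has shape [N x D] (float dtype)
X -= np.mean(X, axis = 0) # zero-center the data
cov = np.dot(X.T, X) / X.shape[0] # compute the [D x D] covariance matrix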

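Element (i, j) of the resulting matrix is the covariance between dimensions i and j of the original data. Notably, the diagonal of the covariance matrix holds the variance of each dimension. Moreover, the covariance matrix is symmetric and positive semi-definite, so we can compute its singular value decomposition (SVD):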

U,S,V = np.linalg.svd(cov)

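Here the columns of U are the eigenvectors of the covariance matrix, and S is a 1-D array of its singular values, which for a symmetric positive semi-definite matrix coincide with its eigenvalues. To decorrelate the (zero-centered) data, we simply project it onto this eigenbasis: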

Xrot = np.dot(X, U)

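It may help to think of it this way: the columns of U form an orthonormal basis, so we are projecting the original data X onto a rotated orthonormal basis of the same dimension, which decorrelates the data. If you compute the covariance matrix of Xrot, you will find that it is diagonal. Another convenient property of np.linalg.svd is that the columns of the returned U are sorted by their eigenvalues in descending order, which means we can use it for dimensionality reduction: keep only the top few eigenvectors and multiply the data by them. This reduces computation while retaining most of the information (variance) in the data. This is Principal Component Analysis (PCA):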

Xrot_reduced = np.dot(X, U[:,:100])

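After this operation, the [N x D] dataset is reduced to [N x 100], keeping the 100 dimensions that contain the most variance. In practice, you can often train various machine learning models on the PCA-reduced data, saving both space and time while still achieving good accuracy.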
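The covariance, SVD, decorrelation, and PCA steps above can be checked end to end on synthetic data. This is a sketch under stated assumptions: the correlated toy data, the sizes `N`, `D`, and `k`, and the tolerance used for the check are all made up for illustration. It confirms that the covariance of `Xrot` is diagonal (with the eigenvalues `S` on the diagonal, in descending order) and that keeping the first `k` columns of `U` yields an `[N x k]` matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, k = 1000, 5, 2
# Toy correlated data: independent Gaussians mixed by a random D x D matrix.
X = rng.standard_normal((N, D)) @ rng.standard_normal((D, D))

X = X - np.mean(X, axis=0)                # zero-center
cov = np.dot(X.T, X) / X.shape[0]         # [D x D] covariance matrix
U, S, Vt = np.linalg.svd(cov)             # columns of U: eigenvectors; S: eigenvalues (descending)

Xrot = np.dot(X, U)                       # decorrelate: express the data in the eigenbasis
cov_rot = np.dot(Xrot.T, Xrot) / N        # covariance of the rotated data
off_diag = cov_rot - np.diag(np.diag(cov_rot))
print(np.abs(off_diag).max() < 1e-8)      # True: off-diagonal entries vanish

Xrot_reduced = np.dot(X, U[:, :k])        # PCA: keep only the top-k components
print(Xrot_reduced.shape)                 # (1000, 2)
print(S[:k].sum() / S.sum())              # fraction of the total variance that is retained
```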

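Finally, a word on whitening. Whitening takes the data in the eigenbasis and divides every dimension by the square root of its eigenvalue, which normalizes the scale along every eigen-axis. Geometrically, if the input data is a multivariate Gaussian, the whitened data is a Gaussian with zero mean and identity covariance matrix. In code: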

# Whiten the data: divide each eigen-axis by the square root of its eigenvalue (S holds the eigenvalues)
Xwhite = Xrot / np.sqrt(S + 1e-5)

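A word of caution: whitening can greatly amplify noise. Note that in the code above we add a small constant 1e-5 to the denominator to prevent division by zero. Noise can still be amplified, however, because whitening essentially stretches every dimension to roughly the same scale, including irrelevant dimensions with very small variance that consist mostly of noise. Using a larger smoothing constant in the denominator (larger than 1e-5) can mitigate this to some extent.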

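The figure below illustrates the data distribution going from the original data, to decorrelated data, to whitened data:
[Figure: decorrelation and whitening]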
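To see both the whitening step and the noise caveat in action, the following sketch whitens toy data that has three dimensions with large variance plus one low-variance "noise" dimension. The data, the helper name `whiten`, and the `eps` values tried are assumptions for illustration. Along each eigen-axis the whitened variance works out to `S / (S + eps)`: with a tiny `eps` every axis, including the noise axis, is stretched to unit variance, while a larger `eps` damps the low-variance axes.

```python
import numpy as np

def whiten(X, eps=1e-5):
    """PCA whitening sketch: rotate into the eigenbasis, then divide each axis by sqrt(eigenvalue + eps)."""
    X = X - np.mean(X, axis=0)
    cov = np.dot(X.T, X) / X.shape[0]
    U, S, _ = np.linalg.svd(cov)
    Xrot = np.dot(X, U)
    return Xrot / np.sqrt(S + eps)

rng = np.random.default_rng(1)
# Toy data: three dimensions with standard deviations 5, 2, 1, plus one "noise" dimension with std 0.01.
X = np.hstack([rng.standard_normal((2000, 3)) * [5.0, 2.0, 1.0],
               rng.standard_normal((2000, 1)) * 0.01])

for eps in (1e-8, 1e-5, 1e-2):
    # Per-axis variance after whitening equals S / (S + eps); the last axis is the noise direction.
    print(eps, np.round(whiten(X, eps).var(axis=0), 3))
```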

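Let us look at these operations on a real dataset, which may make the process clearer. Recall the CIFAR-10 image dataset: its training set is 50000 x 3072, i.e., each image (32 x 32 pixels x 3 color channels) is flattened into a 3072-dimensional vector. We apply the SVD-based steps above to this 50000 x 3072 data matrix and visualize the results: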

[Figure: CIFAR-10 dimensionality reduction and visualization]

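A brief explanation: the leftmost panel shows 49 original images. The second panel shows the top 144 of the 3072 eigenvectors; these 144 eigenvectors account for most of the variance in the data, and they correspond to low-frequency content in the images. The third panel shows the 49 images after PCA reduction using those 144 eigenvectors. The images look blurred, as if seen through a veil, which confirms that the top 144 eigenvectors capture mostly low-frequency information; still, most of the image content is clearly preserved. The rightmost panel shows the whitened 144-dimensional representation mapped back to pixel space by multiplying by U.transpose()[:144,:]. Interestingly, the low-frequency information is now largely filtered out, and the remaining high-frequency content is amplified.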
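CIFAR-10 itself is not loaded here, but the mechanics behind the third and fourth panels can be sketched on synthetic data. Assumptions: a toy 64-dimensional dataset whose variance decays across directions (loosely mimicking how a few low-frequency directions dominate natural images), and hypothetical helpers `reconstruct` and `reconstruct_whitened`. Projecting onto the top `k` eigenvectors and multiplying by `U[:, :k].T` (the same matrix as `U.transpose()[:k, :]`) maps the reduced data back to the original space; the reconstruction error shrinks as `k` grows.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 2000, 64
# Toy data: standard deviations 1, 1/2, 1/3, ... along D directions, then randomly rotated.
X = rng.standard_normal((N, D)) * (1.0 / np.arange(1, D + 1))
X = X @ np.linalg.qr(rng.standard_normal((D, D)))[0]

X = X - X.mean(axis=0)
U, S, _ = np.linalg.svd(X.T @ X / N)

def reconstruct(X, U, k):
    """PCA down to k dims, then back to D dims (cf. the third panel)."""
    return (X @ U[:, :k]) @ U[:, :k].T

def reconstruct_whitened(X, U, S, k, eps=1e-5):
    """Whiten the k-dim PCA representation, then map it back to D dims (cf. the rightmost panel)."""
    return (X @ U[:, :k] / np.sqrt(S[:k] + eps)) @ U[:, :k].T

for k in (4, 16, 64):
    rel_err = np.mean((X - reconstruct(X, U, k)) ** 2) / np.mean(X ** 2)
    print(k, rel_err)   # decreases as k grows; essentially 0 when k == D
```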

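In practice: since this part is about data preprocessing, we have walked through all the basic kinds, but convolutional networks in practice do not use decorrelation or whitening. Mean subtraction, on the other hand, is extremely important, and normalizing each pixel dimension is also common.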

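An important caveat: any preprocessing statistics must be computed on the training set first and then applied to the validation/test sets. For example, some people put all the images together, compute the mean, subtract it, and only then split the data into training and test sets. That is wrong, folks!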
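A minimal sketch of the correct order of operations, assuming toy data and hypothetical helper names (`fit_preprocessing`, `apply_preprocessing`): split first, compute the statistics on the training split only, then apply those same statistics to every split. The `1e-8` added to the standard deviation is an assumed safeguard against division by zero, not something the text above prescribes.

```python
import numpy as np

def fit_preprocessing(X_train):
    """Compute preprocessing statistics on the TRAINING split only."""
    mean = X_train.mean(axis=0)
    std = X_train.std(axis=0) + 1e-8   # assumed small constant to avoid division by zero
    return mean, std

def apply_preprocessing(X, mean, std):
    """Apply the training-set statistics to any split (train, validation, or test)."""
    return (X - mean) / std

rng = np.random.default_rng(0)
X = rng.uniform(0, 255, size=(1000, 8))    # toy data: 1000 samples, 8 features
X_train, X_test = X[:800], X[800:]         # 1) split first

mean, std = fit_preprocessing(X_train)     # 2) fit statistics on the training split only
X_train_p = apply_preprocessing(X_train, mean, std)
X_test_p = apply_preprocessing(X_test, mean, std)   # 3) reuse the SAME statistics on test data
```

The test split is never looked at when the statistics are computed, so its preprocessed mean will generally be close to, but not exactly, zero.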

2.2 Weight Initialization

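We have now seen how a complete neural network is built from neurons and connections, and how to preprocess the data. Before training the network, one more task remains: initializing the parameters.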

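Pitfall: initializing everything to 0. Some people reason that since training will converge anyway, the initial values do not matter, so why not simply set them all to 0? This absolutely does not work! Why? We cannot know the final weights before training, but with properly normalized data it is reasonable to assume that roughly half of the weights will end up positive and half negative, so 0, their expected value, might look like a sensible guess. However, if all weights start at 0, every neuron computes the same output, receives the same gradient during backpropagation, and therefore undergoes exactly the same update. The neurons can never become different from one another, so the network cannot learn anything useful.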
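To make the symmetry argument concrete, here is a toy experiment: a tiny two-layer tanh network without biases, trained by full-batch gradient descent on a made-up regression target with mean squared error. The network, the target, and the helper names `train` and `hidden_units_identical` are all assumptions for illustration. Starting from all zeros, every hidden unit's weight vector stays identical (with no biases it in fact stays exactly zero), whereas small random weights give each unit a different starting point.

```python
import numpy as np

def train(W1, W2, X, y, lr=0.1, steps=50):
    """Full-batch gradient descent on a tiny 2-layer tanh network (no biases) with MSE loss."""
    for _ in range(steps):
        h = np.tanh(X @ W1)                  # hidden activations [N, H]
        out = h @ W2                         # predictions [N, 1]
        dout = 2.0 * (out - y) / len(X)      # d(mean squared error) / d(out)
        dW2 = h.T @ dout                     # gradient for the output weights [H, 1]
        dh = dout @ W2.T                     # gradient flowing back into the hidden layer
        dW1 = X.T @ (dh * (1.0 - h ** 2))    # backprop through tanh
        W1, W2 = W1 - lr * dW1, W2 - lr * dW2
    return W1, W2

def hidden_units_identical(W1):
    """True if every hidden unit (column of W1) has the same weight vector."""
    return np.allclose(W1, W1[:, :1])

rng = np.random.default_rng(0)
X = rng.standard_normal((64, 4))
y = np.sin(X.sum(axis=1, keepdims=True))    # made-up regression target
D, H = 4, 8

W1_zero, _ = train(np.zeros((D, H)), np.zeros((H, 1)), X, y)
W1_rand, _ = train(0.01 * rng.standard_normal((D, H)), 0.01 * rng.standard_normal((H, 1)), X, y)
print(hidden_units_identical(W1_zero))   # True: the hidden units never differentiate
print(hidden_units_identical(W1_rand))   # False: random init breaks the symmetry
```

The same lockstep behavior appears with any constant initialization, e.g. `np.full((D, H), 0.5)`: the hidden units start equal and receive equal updates, so they stay equal.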

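Small random numbers. We still want the initial weights to be small, close to 0, but, as we just discussed, not exactly 0. So in practice we initialize the weights to very small random numbers, roughly half positive and half negative. This way the initial weights all differ, which breaks the symmetry, and the neurons no longer update in lockstep. For example, we might initialize a weight matrix as W=0.0001*np.random.randn(D,H). This samples each neuron's weight vector from a multivariate Gaussian, so the neurons' initial weight vectors point in random directions in the input space.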

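A note of caution: smaller initial values are not necessarily better than larger ones. In a network with very small weights, the gradients computed during backpropagation are also very small, since the gradient passed back to a layer's inputs is proportional to that layer's weights. The backward "signal" is therefore weak, which becomes a real problem in very deep networks: by the time the gradient reaches the earliest layers, it has almost vanished.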
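The effect described above can be made concrete with a small experiment. This sketch assumes a 10-layer tanh network of width 256 with no biases, random inputs, and a random unit-variance gradient injected at the top; the helper `forward_backward_stds` is hypothetical. It compares the example scale `0.0001` with a scale of `1/sqrt(fan-in)` and records the standard deviation of the activations and of the backpropagated gradients at each layer.

```python
import numpy as np

def forward_backward_stds(scale, n_layers=10, width=256, batch=500, seed=0):
    """Std of activations (forward) and of gradients (backward) per layer for W = scale * randn."""
    rng = np.random.default_rng(seed)
    Ws = [scale * rng.standard_normal((width, width)) for _ in range(n_layers)]
    h = rng.standard_normal((batch, width))          # unit-variance input batch
    hs = []
    for W in Ws:                                     # forward pass through the tanh layers
        h = np.tanh(h @ W)
        hs.append(h)
    g = rng.standard_normal((batch, width))          # pretend upstream gradient at the top
    grad_stds = []
    for W, h in zip(reversed(Ws), reversed(hs)):     # backward pass, top layer first
        g = (g * (1.0 - h ** 2)) @ W.T               # through tanh, then through the linear map
        grad_stds.append(g.std())
    return [h.std() for h in hs], grad_stds

act_tiny, grad_tiny = forward_backward_stds(1e-4)
act_cal, grad_cal = forward_backward_stds(1.0 / np.sqrt(256))
print(f"{grad_tiny[-1]:.1e}")   # gradient reaching the input layer: vanishingly small
print(f"{grad_cal[-1]:.2f}")    # stays on a usable scale
```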

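Calibrating the variance. The suggestion above has one problem: for a randomly initialized neuron, the variance of its output grows with the number of inputs. We can normalize the output variance of each neuron to 1 by dividing its weight vector by the square root of the number of inputs (the fan-in). That is, we initialize each neuron's weight vector as w = np.random.randn(n) / np.sqrt(n), where n is the number of inputs.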
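A quick empirical check of this claim, under assumed sizes (`n = 500` inputs, 4000 random input vectors) and standard-normal inputs: with plain `randn` weights the variance of the pre-activation `s` comes out near `n`, while dividing by `np.sqrt(n)` brings it near 1. For a single fixed draw of `w`, the variance over inputs is `np.sum(w**2) * Var(x)`, which is approximately `n * Var(w) * Var(x)`.

```python
import numpy as np

rng = np.random.default_rng(0)
n, trials = 500, 4000
x = rng.standard_normal((trials, n))          # zero-mean, unit-variance inputs, one row per trial

w_raw = rng.standard_normal(n)                # plain randn weights
w_cal = rng.standard_normal(n) / np.sqrt(n)   # calibrated: scale by 1/sqrt(fan-in)

s_raw = x @ w_raw                             # pre-activation s = sum_i w_i x_i for each trial
s_cal = x @ w_cal
print(s_raw.var(), s_cal.var())               # roughly n and roughly 1, respectively
print(n * w_cal.var() * x.var())              # the prediction n * Var(w) * Var(x)
```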

Let us briefly explain, mathematically, why this normalizes the variance. Consider the inner product between the weights w and the input x that is computed before the activation function, $s = \sum_{i=1}^{n} w_i x_i$, and compute the variance of $s$:

$$
\begin{aligned}
\text{Var}(s) &= \text{Var}\Big(\sum_{i=1}^{n} w_i x_i\Big) \\
&= \sum_{i=1}^{n} \text{Var}(w_i x_i) \\
&= \sum_{i=1}^{n} \Big( [E(w_i)]^2\,\text{Var}(x_i) + [E(x_i)]^2\,\text{Var}(w_i) + \text{Var}(x_i)\,\text{Var}(w_i) \Big) \\
&= \sum_{i=1}^{n} \text{Var}(x_i)\,\text{Var}(w_i) \\
&= \big( n\,\text{Var}(w) \big)\,\text{Var}(x)
\end{aligned}
$$

The first two steps of this derivation use properties of variance. They require the terms $w_i x_i$ to be independent of each other, and each $w_i$ to be independent of $x_i$. The third step assumes zero-mean inputs and weights, so that $E[x_i] = E[w_i] = 0$. This is only an assumption and does not hold in general. For example, ReLU units have a positive mean. The last step assumes that all $w_i$ and all $x_i$ are identically distributed.

We want $s$ to have the same variance as its input $x$, so we need $\text{Var}(w) = 1/n$. For a constant $a$, $\text{Var}(aX) = a^2\,\text{Var}(X)$. We can therefore draw each weight from a unit Gaussian and scale it by $a = \sqrt{1/n}$. In numpy this is w = np.random.randn(n) / np.sqrt(n).
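
The following sketch checks the derivation above numerically. It draws zero-mean Gaussian inputs and weights with numpy. It then compares the empirical variance of $s = \sum_i w_i x_i$ with $n\,\text{Var}(w)\,\text{Var}(x)$, first for unit-variance weights and then for weights scaled by $1/\sqrt{n}$. A few choices here are assumptions made for illustration, not taken from the text:

- The helper name output_variance.
- The sizes n and num_samples.
- Drawing fresh weights for every sample. This treats $w_i$ as a random variable, as the derivation does. In a real layer, one weight vector is shared by all samples.

```python
import numpy as np

def output_variance(n, weight_scale, x_std=1.0, num_samples=5000, seed=0):
    """Empirical Var(s) for s = sum_i w_i * x_i with zero-mean Gaussian w and x.

    Every sample gets fresh, independent x (std = x_std) and w (std = weight_scale),
    so all w_i, x_i are independent and identically distributed, as assumed above.
    """
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((num_samples, n)) * x_std         # Var(x) = x_std**2
    w = rng.standard_normal((num_samples, n)) * weight_scale  # Var(w) = weight_scale**2
    s = np.sum(w * x, axis=1)                                 # one s per sample
    return s.var()

n = 256
# Unscaled weights: Var(w) = 1, so Var(s) should be close to n * 1 * 1 = n.
var_raw = output_variance(n, weight_scale=1.0)
# Calibrated weights (randn / sqrt(n)): Var(w) = 1/n, so Var(s) should be close to Var(x) = 1.
var_calibrated = output_variance(n, weight_scale=1.0 / np.sqrt(n))
print(var_raw, var_calibrated)
```

Up to sampling noise, the first value should land near n and the second near 1. This is the growth that the $1/\sqrt{n}$ scaling is meant to cancel.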

There is further research and advice on weight initialization along the same lines. For example, Glorot's paper Understanding the difficulty of training deep feedforward neural networks recommends an initialization that satisfies $\text{Var}(w) = 2/(n_{in} + n_{out})$. Here $n_{in}$ and $n_{out}$ are the numbers of neurons in the previous and the next layer.

A more recent paper, Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification, focuses on ReLU neurons specifically. It concludes that the weight variance should be $2.0/n$, i.e. w = np.random.randn(n) * np.sqrt(2.0/n). Current neural networks use ReLU units heavily, so this is the setting most often used in practice.
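
The sketch below compares these rules. It pushes random unit-variance data through a stack of fully connected ReLU layers, initialized with each variance choice in turn, and reports the mean squared activation of the last layer.

The extra factor of 2 comes from the ReLU. For $x = \max(0, z)$ with $z$ symmetric around zero, $E[x^2] = \tfrac{1}{2}\text{Var}(z)$. With zero-mean weights, the next layer's pre-activation variance is $n\,\text{Var}(w)\,E[x^2]$. So $\text{Var}(w) = 2/n$ keeps the scale roughly constant with depth, while $1/n$ roughly halves it at every layer.

The following are assumptions made for this illustration:

- The width, the depth and the biases fixed at 0.
- The helper names init_weights and final_activation_scale.
- Equal layer widths. Because $n_{in} = n_{out}$, the Glorot rule coincides with $1/n$ in this setup.

```python
import numpy as np

def init_weights(n_in, n_out, scheme, rng):
    """Gaussian weight matrix of shape (n_in, n_out) with the chosen variance."""
    if scheme == "1/n":
        var = 1.0 / n_in
    elif scheme == "glorot":          # Var(w) = 2 / (n_in + n_out)
        var = 2.0 / (n_in + n_out)
    elif scheme == "he":              # Var(w) = 2 / n_in, derived for ReLU units
        var = 2.0 / n_in
    else:
        raise ValueError(f"unknown scheme: {scheme}")
    return rng.standard_normal((n_in, n_out)) * np.sqrt(var)

def final_activation_scale(scheme, width=256, depth=10, num_samples=1000, seed=0):
    """Mean squared ReLU activation after `depth` layers (all biases are 0)."""
    rng = np.random.default_rng(seed)
    h = rng.standard_normal((num_samples, width))   # unit-variance input data
    for _ in range(depth):
        W = init_weights(width, width, scheme, rng)
        h = np.maximum(0.0, h @ W)                  # linear map followed by ReLU
    return np.mean(h ** 2)

for scheme in ["1/n", "glorot", "he"]:
    print(scheme, final_activation_scale(scheme))
```

Under these assumptions, the "he" value should stay on the order of 1. The other two should shrink by roughly a factor of 2 per layer, which after 10 layers leaves very little signal.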

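Bias initialization: initializing the biases is comparatively simple, and for simplicity they are usually just set to 0. For ReLU units, some people instead use a small constant such as 0.01 for every bias. They argue that this ensures every ReLU unit is active at the start of training, so it does not cut off the gradient flowing back during backpropagation. In actual experiments, however, this trick does not consistently help, so biases are still most often initialized to 0.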
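
The small-bias argument can be checked directly. This sketch measures what fraction of ReLU units in one layer are active (positive pre-activation) at initialization, with all biases set to 0 and then to 0.01. Several details are assumptions for illustration:

- The layer sizes.
- Zero-mean Gaussian inputs.
- The He-style weight scaling.
- The helper name active_fraction.

Because the same weights and inputs are reused for both runs, a positive bias can only switch units on, never off. How much it changes the fraction depends on how large 0.01 is compared with the spread of the pre-activations.

```python
import numpy as np

def active_fraction(bias_value, n_in=256, n_out=256, num_samples=1000, seed=0):
    """Fraction of ReLU units with a positive pre-activation at initialization."""
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((num_samples, n_in))                  # zero-mean inputs
    W = rng.standard_normal((n_in, n_out)) * np.sqrt(2.0 / n_in)  # He-style weights
    b = np.full(n_out, bias_value)                                # same bias for every unit
    Z = X @ W + b                                                 # pre-activations
    return np.mean(Z > 0)

print("bias = 0:   ", active_fraction(0.0))
print("bias = 0.01:", active_fraction(0.01))
```

With zero bias, about half of the units are already active, because the pre-activations are symmetric around 0. The pre-activation standard deviation here is about $\sqrt{2}$, so a 0.01 bias shifts the threshold by less than 1% of a standard deviation. That is one way to see why the trick may make little difference in a setup like this.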

2.3 Regularization

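In the previous section we said that we use regularization to control a neural network so that it does not overfit so easily. There are several types of regularization to choose from: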

  • L2 regularization: we have mentioned this before, and it is very common. It is also simple to implement: we add a penalty on every parameter to the loss function. That is, for every weight $w$ we add the term $\frac{1}{2} \lambda w^2$ to the loss, where $\lambda$ is the regularization strength we can tune. The factor 1/2 is there so that the gradient of this term with respect to $w$ is simply $\lambda w$ instead of $2\lambda w$. L2 regularization is easy to interpret. It heavily penalizes very large weights and so prefers to spread the weights evenly, instead of concentrating them on a small subset of the input dimensions. Adding the L2 term also means that every gradient descent update pulls each weight linearly toward zero by W += -lambda * W (scaled by the learning rate).

  • L1 regularization: this is another common form of regularization. For each weight $w$ we add the penalty $\lambda |w|$ to the loss. Sometimes experienced practitioners even combine L1 and L2 regularization, adding the penalty $\lambda_1 |w| + \lambda_2 w^2$. L1 regularization has a distinctive property: during training it makes the weight vectors progressively sparse. In the end, each neuron keeps only the weights that influence the result most, while irrelevant inputs (such as "noise") receive no weight and are suppressed. As a result, L2 regularization typically produces weight vectors made of diffuse, small values, whereas L1 keeps only the influential weights. In practice, L2 regularization usually performs better than L1 unless you specifically need to keep only a subset of the features.

  • Max norm constraint. Another form of regularization puts an absolute upper bound on the magnitude of each neuron's weight vector. In practice it works like this: we add no penalty term and compute the loss and perform the parameter update as usual. After each update, we clamp every neuron's weight vector $\vec{w}$ so that it satisfies $\Vert \vec{w} \Vert_2 < c$. Some people report that this form of regularization improves their final models. It also has an appealing property: even when the learning rate is set very high, the weights stay bounded, so the network cannot "explode" during training.
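The clamping step described above can be sketched in NumPy as follows. This is a minimal illustration, not a reference implementation. The helper name `max_norm_clamp`, the threshold `c=3.0`, and the random weights and gradient are assumptions. Each row of `W` is taken to be one neuron's incoming weight vector, which matches the `np.dot(W1, X)` layout used in the dropout code of this section. Clamping yields $\Vert \vec{w} \Vert_2 \le c$, the usual practical form of the constraint.

```python
import numpy as np

def max_norm_clamp(W, c=3.0):
    """Rescale each row of W (one neuron's incoming weights) so that its
    L2 norm is at most c. Rows already inside the ball are left unchanged."""
    norms = np.linalg.norm(W, axis=1, keepdims=True)      # shape (H, 1)
    # factor is 1 where norm <= c and c / norm otherwise (guard against norm 0)
    scale = np.minimum(1.0, c / np.maximum(norms, 1e-12))
    return W * scale

# Hypothetical usage: an ordinary SGD step followed by the clamp
rng = np.random.default_rng(0)
W = rng.normal(scale=5.0, size=(4, 10))   # 4 neurons, 10 inputs each
dW = rng.normal(size=W.shape)             # stand-in for a real gradient
learning_rate = 0.1
W -= learning_rate * dW                   # normal update, no penalty term
W = max_norm_clamp(W, c=3.0)              # then enforce ||w||_2 <= c
print(np.linalg.norm(W, axis=1))
```

The constraint is applied only after the update, so the loss and gradient code do not change.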

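  • Dropout. This is one of the most widely used regularization techniques in practical neural network training, and it is also very effective. Srivastava et al. first proposed it as a regularizer in the paper Dropout: A Simple Way to Prevent Neural Networks from Overfitting. In one sentence: during training, each neuron is kept active with probability p and switched off (its output set to zero) with probability 1-p.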

The figure below illustrates dropout in a 3-layer neural network:


[Figure: Dropout illustration]

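One way to see it: during training we sample a subset of all neurons, each kept with probability p, and only the sampled neurons take part in the forward pass and have their parameters updated. This turns the fully connected network on the left of the figure into the thinned network on the right. Note that at test time we do not apply dropout. Instead, following the probability argument, we scale each hidden layer's output by p.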

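A simple dropout implementation is shown below. This is a naive version that we do not recommend using; we will explain why and then give an improved version: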


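import numpy as np

p = 0.5 # dropout probability, i.e. the probability of keeping a neuron active
# W1, b1, W2, b2, W3, b3 are assumed to be defined elsewhere

def train_step(X):
  """ X contains the data """

  # forward pass of a 3-layer neural network
  H1 = np.maximum(0, np.dot(W1, X) + b1)
  U1 = np.random.rand(*H1.shape) < p # first dropout mask
  H1 *= U1 # drop!
  H2 = np.maximum(0, np.dot(W2, H1) + b2)
  U2 = np.random.rand(*H2.shape) < p # second dropout mask
  H2 *= U2 # drop!
  out = np.dot(W3, H2) + b3

  # backward pass: compute gradients... (omitted here)
  # parameter update... (omitted here)

def predict(X):
  # test-time forward pass: no dropout, scale each hidden layer's output by p instead
  H1 = np.maximum(0, np.dot(W1, X) + b1) * p
  H2 = np.maximum(0, np.dot(W2, H1) + b2) * p
  out = np.dot(W3, H2) + b3
  return out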

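In train_step above we apply dropout twice, once after each hidden layer. We could also apply it to the input layer. The backward pass is unchanged, except that the gradients flowing back through H1 and H2 must be multiplied by the same masks U1 and U2.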
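To make the backward-pass remark concrete, here is a minimal sketch of a dropout layer split into forward and backward functions. The names `dropout_forward` and `dropout_backward` are hypothetical, and the layer follows the vanilla (non-inverted) scheme of `train_step` above. Since the forward pass computes `H * U` elementwise, the local gradient with respect to `H` is simply `U`, so dropped units receive no gradient.

```python
import numpy as np

def dropout_forward(H, p, rng):
    """Vanilla (non-inverted) dropout: keep each unit with probability p.
    Returns the masked activations and the mask, which backward needs."""
    U = (rng.random(H.shape) < p).astype(H.dtype)   # 1 = keep, 0 = drop
    return H * U, U

def dropout_backward(dout, U):
    """out = H * U elementwise, so d(out)/dH = U: the upstream gradient
    is masked exactly like the activations were."""
    return dout * U

rng = np.random.default_rng(0)
H = rng.random((5, 3))               # pretend hidden activations
out, U = dropout_forward(H, p=0.5, rng=rng)
dout = np.ones_like(out)             # pretend upstream gradient
dH = dropout_backward(dout, U)
```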

A very important point: look closely at predict. We no longer drop units there; instead, every hidden layer's output is scaled by p. This is easiest to understand through expectations. Consider a neuron whose output without dropout is $x$. With dropout, its expected output during training is $px + (1-p)0 = px$. So if at test time we transform every output as $x \rightarrow px$, the expected output stays the same as during training.
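The expectation argument can also be checked numerically. The sketch below uses arbitrary activations `x` and an arbitrary sample count, chosen only for illustration. It averages the masked outputs over many random dropout masks, and the average approaches `p * x`, which is the scaling `predict` applies.

```python
import numpy as np

p = 0.5
x = np.array([0.3, 1.2, 2.0])            # arbitrary activations (no dropout)
rng = np.random.default_rng(42)

# Draw many independent dropout masks and average the resulting outputs
n_samples = 200_000
masks = rng.random((n_samples, x.size)) < p
empirical_mean = (masks * x).mean(axis=0)  # Monte Carlo estimate of E[output]

print(empirical_mean)   # close to p * x
print(p * x)            # the scaling predict() uses at test time
```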

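The drawback of the code above is that at test time we must scale every neuron's output by p. Since test-time speed usually matters a lot in real applications, we can do the opposite and use inverted dropout: apply the inverse scaling (divide by p) during training, so that at test time the weights can be used directly without any extra scaling by p. Example code for inverted dropout: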

""" 
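Inverted Dropout: moves the scaling work from test time to training time, which makes prediction faster
"""
import numpy as np

p = 0.5 # dropout probability, i.e. the probability of keeping a neuron active
# W1, b1, W2, b2, W3, b3 are assumed to be defined elsewhere

def train_step(X):
  # forward pass of a 3-layer neural network
  H1 = np.maximum(0, np.dot(W1, X) + b1)
  U1 = (np.random.rand(*H1.shape) < p) / p # first dropout mask; note the division by p (inverted dropout)
  H1 *= U1 # drop!
  H2 = np.maximum(0, np.dot(W2, H1) + b2)
  U2 = (np.random.rand(*H2.shape) < p) / p # second dropout mask, also divided by p
  H2 *= U2 # drop!
  out = np.dot(W3, H2) + b3

  # backward pass: compute gradients... (omitted here)
  # parameter update... (omitted here)

def predict(X):
  # plain forward pass; no need to multiply by p any more
  H1 = np.maximum(0, np.dot(W1, X) + b1)
  H2 = np.maximum(0, np.dot(W2, H1) + b2)
  out = np.dot(W3, H2) + b3
  return out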
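In practice it is convenient to wrap inverted dropout in a single function with a training flag, so that prediction code needs no knowledge of p at all. The function `inverted_dropout` below is a hypothetical helper written for illustration, not a library API. It also returns the mask, because the backward pass needs it.

```python
import numpy as np

def inverted_dropout(H, p, train, rng=None):
    """Inverted dropout with keep probability p.

    train=True : zero each unit with probability 1-p and divide the
                 survivors by p, so every unit keeps its expected value.
    train=False: identity; no scaling is needed at test time.
    Returns (output, mask); the mask is None at test time.
    """
    if not train:
        return H, None
    if rng is None:
        rng = np.random.default_rng()
    mask = (rng.random(H.shape) < p) / p   # entries are 0 or 1/p
    return H * mask, mask

rng = np.random.default_rng(1)
H = np.full((1000, 50), 2.0)             # constant activations make the check easy
train_out, mask = inverted_dropout(H, p=0.5, train=True, rng=rng)
test_out, _ = inverted_dropout(H, p=0.5, train=False)

print(train_out.mean())   # near 2.0 in expectation
print(test_out.mean())    # test-time output is H itself
```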

If you want to dig deeper into dropout, the following papers are a good start:
* Srivastava et al. (2014), the Dropout paper
* Dropout Training as Adaptive Regularization

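  • Regularizing the bias. As mentioned in an earlier post, we usually do not regularize the bias terms, because they do not interact with the data multiplicatively and therefore do not control how much influence any input dimension has on the final output. That said, regularizing the biases anyway rarely hurts much in practice: there are very few biases compared with the large number of weights, so the effect on the final result is usually small.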
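A common way to follow this advice is to build the L2 penalty from the weight matrices only. The sketch below assumes, hypothetically, that parameters are stored in a dict whose weight arrays are keyed `W1`, `W2`, ... and whose biases are keyed `b1`, `b2`, .... The factor 0.5 is a common convention that makes the gradient of the penalty simply `reg * W`.

```python
import numpy as np

def l2_penalty(params, reg):
    """0.5 * reg * (sum of squared weights), skipping bias vectors.
    Assumes (hypothetically) weights are keyed 'W*' and biases 'b*'."""
    return 0.5 * reg * sum(np.sum(v * v) for k, v in params.items() if k.startswith("W"))

params = {
    "W1": np.ones((3, 4)), "b1": np.full(3, 10.0),
    "W2": np.ones((2, 3)), "b2": np.full(2, 10.0),
}
print(l2_penalty(params, reg=1e-3))   # biases contribute nothing, however large
```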

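In practice: the most common setup is a single, global L2 regularization strength chosen by cross-validation, combined with dropout applied after every layer. A common dropout probability is p=0.5, though this value can also be tuned by cross-validation.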

2.4 Loss Functions

We have discussed data preprocessing, weight initialization, and regularization. Now we return to another key ingredient of training: the loss function. For a network this complex we still need a criterion that measures how well the predictions agree with the ground truth, and that criterion is the loss function. In a neural network, the data loss is computed on each example and then averaged: $L = \frac{1}{N} \sum_i L_i$, where N is the number of training examples.

2.4.1 Classification

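  • Classification is the problem we have been discussing so far. We assume that every example in the dataset has exactly one correct label (class). We mentioned two loss functions earlier. The first is the SVM hinge loss: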

$$L_i = \sum_{j\neq y_i} \max(0, f_j - f_{y_i} + 1)$$

The second is the cross-entropy loss used by the Softmax classifier:

$$L_i = -\log\left(\frac{e^{f_{y_i}}}{ \sum_j e^{f_j} }\right)$$
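Both losses, together with the averaging $L = \frac{1}{N}\sum_i L_i$, can be written in a few lines of vectorized NumPy. This sketch assumes that `scores` is an (N, C) array with one row of class scores $f$ per example and that `y` holds the integer labels $y_i$. The softmax version subtracts each row's maximum score before exponentiating. This standard trick leaves the loss unchanged but avoids overflow in `np.exp`.

```python
import numpy as np

def svm_loss(scores, y):
    """Multiclass hinge loss averaged over N examples.
    scores: (N, C) class scores f; y: (N,) integer labels y_i."""
    N = scores.shape[0]
    correct = scores[np.arange(N), y][:, None]        # f_{y_i}, shape (N, 1)
    margins = np.maximum(0, scores - correct + 1.0)   # max(0, f_j - f_{y_i} + 1)
    margins[np.arange(N), y] = 0                      # exclude the j == y_i terms
    return margins.sum() / N                          # L = (1/N) sum_i L_i

def softmax_loss(scores, y):
    """Softmax cross-entropy loss averaged over N examples."""
    N = scores.shape[0]
    shifted = scores - scores.max(axis=1, keepdims=True)   # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(N), y].mean()

scores = np.array([[3.2, 5.1, -1.7],
                   [1.3, 4.9,  2.0]])
y = np.array([0, 1])
print(svm_loss(scores, y), softmax_loss(scores, y))
```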
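  • Problem: a very large number of classes. When the label set is extremely large (for example, ImageNet contains 22,000 categories), it can help to use hierarchical Softmax. It organizes the labels into a tree, so that every class corresponds to a path in the tree, and a Softmax is trained at each tree node to decide between the left and the right branch.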
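A minimal sketch of how hierarchical Softmax turns a C-way decision into a sequence of binary ones. Everything here is an illustrative assumption rather than a specific published variant. The tree is a complete binary tree over C classes (C a power of 2), each internal node has one weight vector, and the two-way Softmax at each node is written in its equivalent sigmoid form. A class's probability is the product of the branch probabilities along its path, so only about log2(C) node classifiers are evaluated per class instead of C scores.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def hsoftmax_prob(h, node_W, c, num_classes):
    """P(class c | h) under a hierarchical softmax on a complete binary tree.

    Illustrative assumptions: num_classes is a power of 2; internal nodes use
    heap indexing 1..num_classes-1 (root = 1, children of n are 2n and 2n+1);
    the leaf for class c is node num_classes + c; node_W[n] is the weight
    vector of the binary classifier at internal node n (row 0 is unused).
    """
    prob = 1.0
    node = num_classes + c
    while node > 1:                          # walk from the leaf up to the root
        parent = node // 2
        p_right = sigmoid(node_W[parent] @ h)
        prob *= p_right if node % 2 == 1 else 1.0 - p_right
        node = parent
    return prob

rng = np.random.default_rng(0)
C, D = 8, 5                                  # 8 classes -> 3 binary decisions each
node_W = rng.normal(size=(C, D))             # rows 1..C-1 act as internal nodes
h = rng.normal(size=D)                       # hypothetical feature vector
probs = np.array([hsoftmax_prob(h, node_W, c, C) for c in range(C)])
print(probs.sum())
```

The two branch probabilities at every node sum to one, so the leaf probabilities automatically form a valid distribution over the classes.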

  • Attribute classification. Both loss functions above assume that each example has a single correct answer $y_i$. In some settings, however, $y_i$ is a binary vector in which each element indicates whether the example has a certain attribute. What do we do then? For example, an image on Instagram can be viewed as labeled with a subset of a huge set of hashtags, so a single image can carry several tags. A natural first approach is to build an independent binary classifier for every attribute. For example, the per-category binary classifiers could be trained with a loss of the following form:

$$L_i = \sum_j \max(0, 1 - y_{ij} f_j)$$

Here the sum runs over all classes j, and <script type="math/tex" id="MathJax-Element-28">y_{ij}</script> is either 1 or -1, depending on whether the i-th example is labeled with the j-th attribute. The score <script type="math/tex" id="MathJax-Element-29">f_j</script> is positive when the class/label is predicted to be present and negative otherwise. Note that the loss accumulates whenever a positive example scores below +1 or a negative example scores above -1.
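The formula can be illustrated with a minimal NumPy sketch of this multi-label hinge loss, vectorized over a batch. The function name `multilabel_hinge_loss` and the data layout (an [N*C] score matrix `f` and a matching label matrix `y` with entries +1/-1) are assumptions made for this example. They are not code from the original post.

```python
import numpy as np

def multilabel_hinge_loss(f, y):
    """Per-example multi-label hinge loss L_i = sum_j max(0, 1 - y_ij * f_j).

    f: scores of shape (N, C); f[i, j] is the score f_j of example i.
    y: labels of shape (N, C), each entry +1 (attribute present) or -1 (absent).
    Returns an array of shape (N,) holding one loss value per example.
    """
    # A term is zero once its score is on the correct side of 0 by a margin of at least 1
    margins = np.maximum(0.0, 1.0 - y * f)
    return margins.sum(axis=1)

# Hypothetical example: one image, three attributes
f = np.array([[2.0, 0.5, -0.25]])
y = np.array([[1, 1, -1]])
print(multilabel_hinge_loss(f, y))  # → [1.25]
```

In this example the first attribute contributes nothing because its positive score already exceeds +1. The second attribute (score 0.5 < +1) and the third (score -0.25 > -1) each add to the loss.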

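Another approach that may work well is to train an independent logistic regression classifier for each attribute. A binary logistic regression classifier has only two classes, 0 and 1, and assigns the following probability to class 1: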

<script type="math/tex; mode=display" id="MathJax-Element-30">P(y = 1 \mid x; w, b) = \frac{1}{1 + e^{-(w^Tx +b)}} = \sigma (w^Tx + b)</script>
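The probability above can be sketched for several attributes at once by giving each attribute its own weight row and bias. The names `attribute_probabilities`, `W`, and `b`, and the toy numbers are hypothetical and chosen only for illustration.

```python
import numpy as np

def sigmoid(z):
    """Logistic function sigma(z) = 1 / (1 + exp(-z))."""
    return 1.0 / (1.0 + np.exp(-z))

def attribute_probabilities(x, W, b):
    """P(y_j = 1 | x) for C independent binary logistic regression classifiers.

    x: input vector of shape (D,)
    W: weights of shape (C, D); row j holds w for attribute j
    b: biases of shape (C,)
    """
    scores = W.dot(x) + b    # one score w_j^T x + b_j per attribute
    return sigmoid(scores)   # each entry is a separate probability in (0, 1)

# Hypothetical example: 2 attributes, 2-dimensional input
W = np.array([[0.5, -1.0],
              [0.0,  0.0]])
b = np.array([0.25, 0.0])
x = np.array([2.0, 0.5])
p1 = attribute_probabilities(x, W, b)
p0 = 1.0 - p1                # P(y_j = 0 | x) for each attribute
print(p1[1])                 # a score of 0 gives probability 0.5
```

Unlike a softmax over classes, the entries of `p1` need not sum to 1, because each attribute is modeled as its own yes/no question.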

Since the probabilities of classes 0 and 1 sum to 1, the probability of class 0 is <script type="math/tex" id="MathJax-Element-31">P(y = 0 \mid x; w, b) = 1 - P(y = 1 \mid x; w,b)</script>. A sample is classified as 1 when <script type="math/tex" id="MathJax-Element-32">\sigma (w^Tx + b) > 0.5</script>. Because the sigmoid is monotonically increasing and equals 0.5 at 0, this is equivalent to the score condition <script type="math/tex" id="MathJax-Element-33">w^Tx +b > 0</script>. The loss then follows from maximum likelihood: maximizing the log-likelihood of the labels is the same as minimizing its negative:

<script type="math/tex; mode=display" id="MathJax-Element-34">L_i = -\sum_j \left[ y_{ij} \log(\sigma(f_j)) + (1 - y_{ij}) \log(1 - \sigma(f_j)) \right]</script>

Here the label <script type="math/tex" id="MathJax-Element-35">y_{ij}</script> is 1 (positive example) or 0 (negative example), and <script type="math/tex" id="MathJax-Element-36">\sigma</script> is the sigmoid function.
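A minimal NumPy sketch of this per-attribute logistic loss follows. It is written as a negative log-likelihood so that it can be minimized, and it also returns the gradient with respect to the scores. The function name and the use of `np.logaddexp` for numerical stability are our own choices, not taken from the original post.

```python
import numpy as np

def multilabel_logistic_loss(f, y):
    """Per-attribute logistic loss (negative log-likelihood), summed over attributes.

    f: scores of shape (N, C); y: labels of shape (N, C) with entries 1 or 0.
    Returns (loss of shape (N,), gradient dL/df of shape (N, C)).
    """
    # log(sigma(f)) = -log(1 + e^(-f)) and log(1 - sigma(f)) = -log(1 + e^f);
    # np.logaddexp(0, z) = log(e^0 + e^z) avoids overflow for large |f|
    log_p1 = -np.logaddexp(0.0, -f)
    log_p0 = -np.logaddexp(0.0, f)
    loss = -np.sum(y * log_p1 + (1 - y) * log_p0, axis=1)
    # Derivative of the negative log-likelihood w.r.t. each score: sigma(f) - y
    grad = 1.0 / (1.0 + np.exp(-f)) - y
    return loss, grad

# Hypothetical example: one image, two attributes (first present, second absent)
f = np.array([[0.0, 2.0]])
y = np.array([[1, 0]])
loss, grad = multilabel_logistic_loss(f, y)
print(loss, grad)  # loss[0] equals log(2) + log(1 + e^2)
```

The gradient `sigma(f) - y` is a convenient property of this loss: each score is pushed up when the predicted probability is below its 0/1 label and down when it is above.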

2.4.2 Regression

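Regression is another class of machine learning problems. It is used to predict continuous-valued quantities, such as the price of a house or the length of some object in an image. For regression we usually compute the difference between the predicted and the true values and then measure it with the L2 norm or the L1 norm. The L2 loss for a single example (one image) is the squared L2 norm of this difference: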

<script type="math/tex; mode=display" id="MathJax-Element-37">L_i = \Vert f - y_i \Vert_2^2</script>
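A minimal NumPy sketch of this per-sample loss follows. It assumes `f` and `y_i` are 1-D arrays of the same shape, and the function name `l2_loss` is hypothetical. It also returns the gradient with respect to `f`, which is simply `2 * (f - y_i)`:

```python
import numpy as np

def l2_loss(f, y_i):
    """Squared L2 loss ||f - y_i||_2^2 for one sample, and its gradient w.r.t. f."""
    diff = f - y_i              # per-dimension prediction error
    loss = np.sum(diff ** 2)    # squared Euclidean norm of the error
    grad = 2.0 * diff           # d(loss)/d(f)
    return loss, grad

f = np.array([2.0, 0.5])        # hypothetical network output for one sample
y = np.array([1.0, 1.5])        # hypothetical ground-truth target
loss, grad = l2_loss(f, y)
print(float(loss))              # 2.0
print(grad.tolist())            # [2.0, -2.0]
```

The gradient grows linearly with the error, so samples with large errors push the parameters hardest.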

The L1 loss instead sums the absolute differences over every output dimension j:

<script type="math/tex; mode=display" id="MathJax-Element-38">L_i = \Vert f - y_i \Vert_1 = \sum_j \mid f_j - (y_i)_j \mid</script>
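The same idea for the L1 loss, again as a sketch with a hypothetical function name `l1_loss`. Because |x| has no derivative at 0, the code uses `np.sign` as a subgradient, which returns 0 wherever the prediction matches the target exactly:

```python
import numpy as np

def l1_loss(f, y_i):
    """L1 loss sum_j |f_j - (y_i)_j| for one sample, and a subgradient w.r.t. f."""
    diff = f - y_i
    loss = np.sum(np.abs(diff))  # sum of absolute errors over dimensions j
    grad = np.sign(diff)         # subgradient: +1, -1, or 0 where diff == 0
    return loss, grad

f = np.array([2.0, 0.5, 3.0])    # hypothetical network output
y = np.array([1.0, 1.5, 3.0])    # hypothetical ground-truth target
loss, grad = l1_loss(f, y)
print(float(loss))               # 2.0
print(grad.tolist())             # [1.0, -1.0, 0.0]
```

Unlike the L2 case, each gradient entry has magnitude at most 1 no matter how large the error is.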

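Notes:

  • The L2 loss used for regression is much harder to optimize than a more stable loss such as the Softmax loss used for classification. Intuitively, training a network to output a discrete class decision is far easier than training it to output exactly the right continuous value for every sample.
  • In addition, as mentioned in an earlier post, a classifier such as Softmax does not care much about the exact values of the scores it outputs; it only cares whether the scores of the different classes are separated by an appropriate margin (for example, in a two-class problem, scores of 1 and 9 and scores of 0.1 and 0.9 both pick the second class).
  • Moreover, the L2 loss is less robust: outliers and noisy labels can inflate the loss dramatically and produce very large, misleading gradients.
  • So for a regression problem, first consider whether it can be turned into a classification problem, for example by dividing the output range into intervals (buckets). If you want to predict a product's rating, for instance, you can split the ratings into 1 to 5 stars and treat the task as classification.
  • If the problem really cannot be converted into classification, use the L2 loss with care. For example, in a neural network it is not a good idea to apply dropout right before the L2 loss.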
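To make the robustness point concrete, the sketch below compares the gradients of the L2 and L1 losses on a tiny, made-up 1-D batch in which one target is an outlier (all numbers are hypothetical):

```python
import numpy as np

# Hypothetical 1-D regression batch; the last target is a corrupted label.
pred = np.array([1.0, 2.0, 3.0, 4.0])
target = np.array([1.1, 1.9, 3.2, 40.0])

diff = pred - target
l2_grad = 2.0 * diff        # gradient of sum((pred - target)**2) w.r.t. pred
l1_grad = np.sign(diff)     # subgradient of sum(|pred - target|) w.r.t. pred

print(np.round(l2_grad, 2).tolist())  # [-0.2, 0.2, -0.4, -72.0]
print(l1_grad.tolist())               # [-1.0, 1.0, -1.0, -1.0]
```

Under L2, the single outlier contributes a gradient roughly 180 times larger than any clean sample and dominates the update; under L1, every sample contributes a gradient of the same bounded size.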

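In short, whenever you face a regression problem, first ask whether the outputs could be discretized so that the problem becomes a classification problem.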
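As an illustration of the bucketing idea, here is a sketch that turns hypothetical continuous product scores in [0, 5] into five star classes with `np.digitize`. The bin edges and scores are assumptions made for the example; any classifier (for instance Softmax) could then be trained on the resulting labels:

```python
import numpy as np

# Hypothetical continuous product scores in [0, 5] (the original regression targets).
scores = np.array([0.3, 1.7, 2.5, 3.99, 4.8])

# Inner bin edges splitting [0, 5] into 5 buckets: [0,1), [1,2), [2,3), [3,4), [4,5].
edges = np.array([1.0, 2.0, 3.0, 4.0])
labels = np.digitize(scores, edges)   # class index 0..4, usable as a Softmax target
stars = labels + 1                    # the same buckets read as 1 to 5 stars

print(labels.tolist())  # [0, 1, 2, 3, 4]
print(stars.tolist())   # [1, 2, 3, 4, 5]

# At test time a 5-way classifier outputs class probabilities
# (hypothetical values below); the predicted rating is the most likely bucket.
probs = np.array([0.05, 0.10, 0.15, 0.50, 0.20])
pred_stars = int(np.argmax(probs)) + 1
print(pred_stars)       # 4
```

The number and width of the buckets is a design choice: finer buckets keep more precision but leave fewer training samples per class.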

3. Summary

To summarize:

  • For many neural network problems, we recommend preprocessing the input features: zero-center them, then normalize them to the range [-1, 1].
  • Initialize the weights from a Gaussian distribution with standard deviation <script type="math/tex" id="MathJax-Element-39">\sqrt{2/n}</script>, where n is the number of inputs to the neuron.
  • Use L2 regularization (or max-norm constraints) together with dropout to reduce overfitting.
  • For classification, the most common loss functions are still the SVM hinge loss and the Softmax cross-entropy loss.
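The first three recommendations can be put together in a short sketch. It assumes a hypothetical data matrix `X` of shape [N*D] with raw pixel values, a single fully connected layer with a hypothetical hidden size `H`, and a hypothetical regularization strength `reg`. The factor 0.5 in the L2 penalty is a common convention, not a requirement:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical data: N samples (rows), each with D raw pixel features in [0, 255].
N, D, H = 500, 3072, 100            # H: hypothetical number of hidden units
X = rng.uniform(0.0, 255.0, size=(N, D))

# 1) Zero-center every feature, then scale it into [-1, 1]
#    by dividing by that feature's largest absolute value.
X = X - X.mean(axis=0)
X = X / np.abs(X).max(axis=0)       # assumes no feature is constant (would divide by 0)

# 2) Weight init: Gaussian with std sqrt(2/n), where n = fan-in (D inputs per unit).
W = rng.standard_normal((D, H)) * np.sqrt(2.0 / D)
b = np.zeros(H)

# 3) L2 regularization term added to the data loss, and its gradient w.r.t. W.
reg = 1e-3                          # hypothetical regularization strength
reg_loss = 0.5 * reg * np.sum(W * W)
dW_reg = reg * W
```

Dropout and max-norm constraints act during the training loop itself, so they are not shown here.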