上一章中,我们介绍了神经网络的学习,并通过数值微分计算了神经网络的权重参数的梯度(严格来说,是损失函数关于权重参数的梯度)。数值微分虽然简单,也容易实现,但缺点是计算上比较费时间。本章我们将学习一个能够高效计算权重参数的梯度的方法——误差反向传播法。
正确理解误差反向传播法,我个人认为有两种方法:一种是基于数学式;另一种是基于计算图(computational graph)。
本章希望大家通过计算图,直观地理解误差反向传播法。
5.1 计算图
5.1.1 用计算图求解
问题1:太郎在超市买了2个100日元一个的苹果,消费税是10%,请计算支付金额。
可以将“2”和“1.1”分别作为变量“苹果的个数”和“消费税”标在○外面。
问题2:太郎在超市买了2个苹果、3个橘子。其中,苹果每个100日元,橘子每个150日元。消费税是10%,请计算支付金额。
综上,用计算图解题的情况下,需要按如下流程进行。
1.构建计算图。
2.在计算图上,从左向右进行计算。
这里的第2歩“从左向右进行计算”是一种正方向上的传播,简称为正向传播(forward propagation)。正向传播是从计算图出发点到结束点的传播。 既然有正向传播这个名称,当然也可以考虑反向(从图上看的话,就是从右向左)的传播。实际上,这种传播称为反向传播(backward propagation)。反向传播将在接下来的导数计算中发挥重要作用。
5.1.2 局部计算
局部计算是指,无论全局发生了什么,都能只根据与自己相关的信息输出接下来的结果。
我们用一个具体的例子来说明局部计算。比如,在超市买了2个苹果和其他很多东西
这里的重点是,各个节点处的计算都是局部计算。这意味着,例如苹果和其他很多东西的求和运算(4000 + 200 → 4200)并不关心4000这个数字是如何计算而来的,只要把两个数字相加就可以了。换言之,各个节点处只需进行与自己有关的计算(在这个例子中是对输入的两个数字进行加法运算),不用考虑全局
5.1.3 为何用计算图解题
那么计算图到底有什么优点呢?
一个优点就在于前面所说的局部计算。无论全局是多么复杂的计算,都可以通过局部计算使各个节点致力于简单的计算,从而简化问题。另一个优点是,利用计算图可以将中间的计算结果全部保存起来(比如,计算进行到2个苹果时的金额是200日元、加上消费税之前的金额650日元等)。但是只有这些理由可能还无法令人信服。实际上,使用计算图最大的原因是,可以通过反向传播高效计算导数。
这里,假设我们想知道苹果价格的上涨会在多大程度上影响最终的支付金额,即求“支付金额关于苹果的价格的导数”。设苹果的价格为x,支付金额为L,则相当于求
。这个导数的值表示当苹果的价格稍微上涨时,支付金额会增加多少。
反向传播使用与正方向相反的箭头(粗线)表示。反向传播传递“局部导数”,将导数的值写在箭头的下方。在这个例子中,反向传播从右向左传递导数的值(1 → 1.1 → 2.2)。从这个结果中可知,“支付金额关于苹果的价格的导数”的值是2.2。这意味着,如果苹果的价格上涨1日元,最终的支付金额会增加2.2日元(严格地讲,如果苹果的价格增加某个微小值,则最终的支付金额将增加那个微小值的2.2倍)。
5.2 链式法则
5.2.1 计算图的反向传播
让我们先来看一个使用计算图的反向传播的例子。假设存在y = f(x)的计算,这个计算的反向传播如图5-6所示。
如图所示,反向传播的计算顺序是,将信号E乘以节点的局部导数
,然后将结果传递给下一个节点。这里所说的局部导数是指正向传播中y = f(x)的导数,也就是y关于x的导数
。比如,假设y = f(x) =
, 则局部导数为
= 2x。把这个局部导数乘以上游传过来的值(本例中为E),然后传递给前面的节点。
5.2.2 什么是链式法则
介绍链式法则时,我们需要先从复合函数说起。复合函数是由多个函数构成的函数。比如,z = (x + y) 2 是由式(5.1)所示的两个式子构成的。
链式法则是关于复合函数的导数的性质,定义如下。
如果某个函数由复合函数表示,则该复合函数的导数可以用构成复合函数的各个函数的导数的乘积表示。
这就是链式法则的原理,乍一看可能比较难理解,但实际上它是一个非常简单的性质。以式(5.1)为例,
(z关于x的导数)可以用
(z关于t 的导数)和
(t关于x的导数)的乘积表示。用数学式表示的话,可以写成式(5.2)。
式(5.2)中的∂ t正好可以像下面这样“互相抵消”,所以记起来很简单
所以最后要计算的结果
5.2.3 链式法则和计算图
现在我们尝试将式(5.4)的链式法则的计算用计算图表示出来。如果用“**2”节点表示平方运算的话,则计算图如图5-7所示。
根据链式法则,
成立,对应“z关于x的导数”。也就是说,反向传播是基于链式法则的。
5.3 反向传播
5.3.1 加法节点的反向传播
这里以z = x + y为对象,观察它的反向传播。z = x + y的导数可由下式(解析性地)计算出来。
在图5-9中,反向传播将从上游传过来的导数(本例中是
)乘以1,然后传向下游。也就是说,因为加法节点的反向传播只乘以1,所以输入的值会原封不动地流向下一个节点。
另外,本例中把从上游传过来的导数的值设为
。这是因为,如图5-10 所示,我们假定了一个最终输出值为L的大型计算图。
的计算位于这个大型计算图的某个地方,从上游会传来
,并向下游传递
和
现在来看一个加法的反向传播的具体例子。假设有“10 + 5=15”这一计算,反向传播时,从上游会传来值1.3。用计算图表示的话,如图5-11所示。
5.3.2 乘法节点的反向传播
这里我们考虑z = xy。这个式子的导数用式(5.6)表示。
乘法的反向传播会将上游的值乘以正向传播时的输入信号的“翻转值”后传递给下游。翻转值表示一种翻转关系,如图5-12所示,正向传播时信号是x的话,反向传播时则是y;正向传播时信号是y的话,反向传播时则是x。
现在我们来看一个具体的例子。比如,假设有“10 × 5 = 50”这一计算,反向传播时,从上游会传来值1.3。用计算图表示的话,如图5-13所示。
因为乘法的反向传播会乘以输入信号的翻转值,所以各自可按1.3 × 5 = 6.5、1.3 × 10 = 13计算。另外,加法的反向传播只是将上游的值传给下游,并不需要正向传播的输入信号。但是,乘法的反向传播需要正向传播时的输入信号值。因此,实现乘法节点的反向传播时,要保存正向传播的输入信号。
5.3.3 苹果的例子
苹果的例子(2个苹果和消费税)。这里要解的问题是苹果的价格、苹果的个数、消费税这3个变量各自如何影响最终支付的金额。这个问题相当于求“支付金额关于苹果的价格的导数”“支付金额关于苹果的个数的导数”“支付金额关于消费税的导数”。用计算图的反向传播来解的话,求解过程如图5-14所示。
结果可知,苹果的价格的导数是2.2,苹果的个数的导数是110,消费税的导数是200。这可以解释为,如果消费税和苹果的价格增加相同的值,则消费税将对最终价格产生200倍大小的影响,苹果的价格将产生2.2倍大小的影响。