
Optimization III - Statistical Manifold and Natural Gradient Descent

Perspectives on Natural Gradient Descent

In our article Optimization I - Gradient Descent, we talked about how to view gradient descent as an approximation of Riemannian gradient descent. However, is there a descent method more loyal to Riemannian gradient descent? In other words, one that is better than gradient descent in some cases? In this article, we start our journey into the beautiful and attractive world of information geometry. In that abstract and fancy world, a secret equivalence between several gradient descent methods is revealed.

KL-Divergence

Given two distributions \(p( y \mid x)\) and \(p(y \mid x')\), the KL divergence measures the distance between them: \[\begin{align*} KL(p( y \mid x), p(y \mid x')) &= \mathbb{E}[\log (p(y \mid x))] - \mathbb{E}[\log (p(y \mid x'))] \\ &= \sum_{y} p(y \mid x) \log \left(\frac{p( y \mid x)}{p(y \mid x')}\right) \end{align*}\] where the expectation is taken with respect to \(p(y \mid x)\). There is plenty of additional material on the KL divergence; one can refer to Wikipedia or Jake Tae's Blog for more information.

From this formulation of the KL divergence we can conclude that \[ \nabla_{x'}^2 KL(p( y \mid x), p(y \mid x')) \Big|_{x'=x} = -\mathbb{E}[\nabla^2 \log (p(y \mid x'))] \Big|_{x'=x} = F. \] Here \(F\) is the Fisher information matrix. The connection between \(F\) and \(KL\) also shows up in the following second-order approximation: \[ KL(p( y \mid x), p(y \mid x+\epsilon)) \approx \frac{1}{2} \epsilon^T F \epsilon. \] This equation tells us that on a manifold where distances are measured by the KL divergence, the Riemannian metric is \(F\).
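To make this concrete, here is a small numerical check, a minimal sketch assuming a categorical distribution parameterized by softmax logits (my own illustrative choice; for that parameterization the Fisher matrix is \(F = \mathrm{diag}(p) - pp^T\)): the exact KL divergence under a small perturbation \(\epsilon\) of the logits should nearly match \(\frac{1}{2}\epsilon^T F \epsilon\).

```python
import numpy as np

# Sketch: for a categorical distribution with softmax logits theta,
# the Fisher information matrix is F = diag(p) - p p^T, and
# KL(p_theta || p_{theta + eps}) ~= 0.5 * eps^T F eps for small eps.

def softmax(theta):
    z = np.exp(theta - theta.max())
    return z / z.sum()

def kl(p, q):
    return np.sum(p * np.log(p / q))

theta = np.array([0.2, -0.5, 1.0])       # arbitrary logits for the example
p = softmax(theta)
F = np.diag(p) - np.outer(p, p)          # Fisher information matrix

eps = 1e-3 * np.array([1.0, -2.0, 0.5])  # a small perturbation of the logits
exact = kl(p, softmax(theta + eps))
approx = 0.5 * eps @ F @ eps
print(exact, approx)                     # the two numbers should nearly agree
```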

Proximal View with Generalized Distance -- "Divergence"

\[\begin{align*} \min \ &f(x_k) + \langle \nabla f(x_k), (x - x_k) \rangle \\ s.t. \ \ &x \in \mathcal{M} \\ &x \in B_\epsilon(x_k) \end{align*}\]

Recall that we derived gradient descent (GD) from the problem above. Here, how we define the ball \(B_\epsilon(x_k)\) around \(x_k\) is an important design choice. In our first post we used the \(l_2\) norm to measure the distance. But selecting a norm without considering the real meaning behind the points is questionable. What if a point represents a distribution? In that case, a better way to compare the distance between points is a divergence.

A divergence is a generalized distance. In particular, here we use the Bregman divergence. Given a convex function \(h \in C^3\), the Bregman divergence \(B_h\) is defined as: \[ B_h(x, y) = h(y) - h(x) - \langle \nabla h(x), y-x \rangle. \] Intuitively, you can understand the Bregman divergence as the gap between \(h(y)\) and \(L_x(y)\), where \(L_x\) is the linear approximation of \(h\) at the point \(x\). See Wikipedia for a picture.
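As a quick sanity check, here is a minimal sketch (the helper names are my own): evaluating the definition with the negative-entropy generator \(h(x) = \sum_i x_i \log x_i\) gives \(B_h(x, y) = \sum_i y_i \log(y_i/x_i)\), i.e. the KL divergence between \(y\) and \(x\), for points on the probability simplex.

```python
import numpy as np

# Sketch: Bregman divergence B_h(x, y) = h(y) - h(x) - <grad h(x), y - x>.
# With the negative-entropy generator h(x) = sum_i x_i log x_i, it reduces to
# sum_i y_i log(y_i / x_i) for points on the probability simplex.

def bregman(h, grad_h, x, y):
    return h(y) - h(x) - np.dot(grad_h(x), y - x)

def h(x):
    return np.sum(x * np.log(x))          # negative entropy

def grad_h(x):
    return np.log(x) + 1.0

x = np.array([0.5, 0.3, 0.2])
y = np.array([0.4, 0.4, 0.2])

kl_y_x = np.sum(y * np.log(y / x))
print(bregman(h, grad_h, x, y), kl_y_x)   # the two values coincide
```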

To save time, I will not repeat the whole procedure used to derive the update. Here I simply provide the result: \[ x_{k+1} = x_k - \nabla^2 h(x_k)^{-1} \nabla f(x_k). \]
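For readers who want the missing step, here is a brief sketch of the omitted derivation (my own filling-in, using a second-order approximation of the divergence around \(x_k\)): \[\begin{align*} x_{k+1} &= \arg\min_x \ f(x_k) + \langle \nabla f(x_k), x - x_k \rangle + B_h(x_k, x) \\ &\approx \arg\min_x \ \langle \nabla f(x_k), x - x_k \rangle + \frac{1}{2}(x - x_k)^T \nabla^2 h(x_k)(x - x_k), \end{align*}\] since \(B_h(x_k, x)\) and its gradient vanish at \(x = x_k\). Setting the gradient of the right-hand side to zero gives \(\nabla f(x_k) + \nabla^2 h(x_k)(x_{k+1} - x_k) = 0\), which is exactly the update above.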

Further, if we consider the manifold of points as a statistical manifold where each point represents a parameterized distribution, then we can obtain the KL divergence as an instance of the Bregman divergence by taking \(h(x) = \sum_{i}x_i\log x_i\). Now the update becomes \[ x_{k+1} = x_k - F(x_k)^{-1}\nabla f(x_k). \] Here \(F\) is the Fisher information matrix, and this update is exactly natural gradient descent.
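To see the update in action, here is a minimal sketch (my own illustrative setup, a slight variation on the formula above): the categorical distribution is parameterized by softmax logits rather than simplex coordinates, its Fisher matrix is \(\mathrm{diag}(p) - pp^T\), and the natural-gradient update is used to fit a target distribution, with plain gradient descent run for comparison.

```python
import numpy as np

# Sketch: natural gradient descent x_{k+1} = x_k - eta * F(x_k)^{-1} grad f(x_k)
# for a categorical distribution with softmax logits theta.  Here
# f(theta) = KL(target || p_theta), grad f = p_theta - target, and the Fisher
# matrix is F = diag(p) - p p^T (singular, so a small damping term is added).

def softmax(theta):
    z = np.exp(theta - theta.max())
    return z / z.sum()

target = np.array([0.7, 0.2, 0.1])

def step(theta, eta, natural, damping=1e-8):
    p = softmax(theta)
    g = p - target                           # gradient of KL(target || p_theta)
    if natural:
        F = np.diag(p) - np.outer(p, p)      # Fisher information matrix
        g = np.linalg.solve(F + damping * np.eye(3), g)
    return theta - eta * g

def run(natural, eta=0.5, steps=50):
    theta = np.zeros(3)
    for _ in range(steps):
        theta = step(theta, eta, natural)
    return softmax(theta)

print("GD :", run(natural=False))
print("NGD:", run(natural=True))
```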

What's More?

Recommended reading

Jake Tae's Blog