AdaGradの数式

SGDの欠点を改善するために代わる手法としてエーダグラッドがあります。AdaGradを”エーダグラッド”という読み方が正しいかはさて置きw
まずは恒例の数式の話です。数式の”◎”は行列の要素ごとの乗算とのことです。こちらの数式について「0からDeepLearning」に詳しく載っております。行列計算のやり方についても優しく解説されていますので、「行列計算って何?」という方も四則演算ができるのであれば、難しくないと思います。

AdaGradのサンプルプログラム

「0からDeepLearning」に記載のプログラムは下記になります。初期関数の”lr”は学習係数になります。
行列計算がNumPyによってさり気なく行われているので数式が、どのようにプログラムに落とし込まれているのかピンとこないですが・・・プログラムは読めますね。
一番最後の行にある”1e-7”ですが、これは0で除算することを防ぐために、ごくごく小さい固定の数値を足しているとのことです。
「数式に、そんな数字はなかった!」と焦らないで大丈夫ですw
class AdaGrad:
def __init__(self, lr=0.01):
self.lr = lr
self.h = None
def update(self, params, grads):
if self.h is None:
self.h = {}
for key, val in params.items():
self.h[key] = np.zeros_like(val)
for key in params.keys():
self.h[key] += grads[key] * grads[key]
params[key] -= self.lr * grads[key] / (np.sqrt(self.h[key]) + 1e-7)AdaGradのグラフ

最後にグラフがどのように表示されるのかをご紹介します。
このグラフは、以下のプログラムを実行した結果です。
モーメンタム(Momentum)や確率的勾配降下法(SGD)に比べると(0,0)の最小値に向かって効率的に動いているのが分かります。ほぼジグザクしていません。y軸方向への更新が、最初の数回で最小値が決まってしまっている点が特徴的です。
import sys, os
sys.path.append(os.pardir)
import matplotlib.pyplot as plt
from collections import OrderedDict
from common.optimizer import *
def f(x, y):
return x ** 2 / 20.0 + y ** 2
def df(x, y):
return x / 10.0, 2.0 * y
init_pos = (-7.0, 2.0)
params = {}
params['x'], params['y'] = init_pos[0], init_pos[1]
grads = {}
grads['x'], grads['y'] = 0, 0
optimizers = OrderedDict()
optimizers["AdaGrad"] = AdaGrad(lr=1.5)
idx = 1
for key in optimizers:
optimizer = optimizers[key]
x_history = []
y_history = []
params['x'], params['y'] = init_pos[0], init_pos[1]
for i in range(10000):
x_history.append(params['x'])
y_history.append(params['y'])
grads['x'], grads['y'] = df(params['x'], params['y'])
optimizer.update(params, grads)
x = np.arange(-10, 10, 0.01)
y = np.arange(-5, 5, 0.01)
X, Y = np.meshgrid(x, y)
Z = f(X, Y)
# for simple contour line
mask = Z > 7
Z[mask] = 0
# plot
plt.subplot(1, 1, idx)
idx += 1
plt.plot(x_history, y_history, 'o-', color="red")
plt.contour(X, Y, Z)
plt.ylim(-10, 10)
plt.xlim(-10, 10)
plt.plot(0, 0, '+')
# colorbar()
# spring()
plt.title(key)
plt.xlabel("x")
plt.ylabel("y")
plt.show()
オライリーJapan 斎藤康毅 著 「セロから作るDeep Learning」