エネルギーベースモデルをノリで実装する
EBMの理論を簡単に学んでMNIST手書き数字を生成する。
EBM: エネルギーベースモデル
エネルギー関数を用いて表される次の確率モデルをエネルギーベースモデル(EBM; Energy-Based Model)と呼ぶ。
このモデル化をすることで、パラメータが記述する関数の制約が緩くなり、選択の自由度が上がる。本来であれば、規格化をにぶん投げるとしても分子は非負性を満たす必要があったが、exponentialによってその制約がなくなっている。
はエネルギー関数と呼ばれ、入力のエネルギーを表す。このモデルは統計力学から着想を得られたもので、エネルギーというのもそこの言葉だと思う。この確率モデルにおいては、エネルギーが低い地点で高い確率密度を持つということになる。マイナスが付いているのでの大きさと確率密度の大きさが反転する。なおマイナスがついているのは統計力学に倣っているためで、本確率モデルにおける意味は何もない。ただただ関数の出力値が反転するだけ。
の制約が緩いため、例えばニューラルネットような表現力の高い関数もモデルに組み込める。本稿でも、パラメータを持つニューラルネットをエネルギー関数として採用し、その最適化(学習)を行う。
ランジュバン・モンテカルロ法
生成モデルとして確率モデルを考える場合、そこからのサンプリングが可能でないといけない。
任意の微分可能な確率分布からのサンプルを得る方法としてランジュバン・モンテカルロ法が存在する。MCMCの一種。適当な分布から初期値をサンプリングし、次の更新式に従ってを更新する。
確率密度の対数の勾配=確率密度を大きくする方向にステップ幅だけを更新し、に応じたノイズを足す。これを回繰り返したをサンプリング結果とする。ノイズが無ければただの勾配上昇法だが、その場合得られるサンプルが確率密度関数の極値に限定されるため、ノイズを加えて極値以外にも辿り着けるようにしている。 、の極限ではからのサンプルに収束する事が知られている。現実的には、を十分小さく、を十分大きくすることでをからのサンプルと見なす。
本稿で扱うニューラルネットベースのEBMにおいてもこのサンプリング手法を用いる。勾配は次のように求められる。
ニューラルネットにその時刻のサンプルを突っ込んで出てきた値を微分したらいいだけ。EBMのサンプリングでは以下の更新式に基づいてを更新する。
EBMの最尤推定
EBMのパラメータの最適化を確認する。最尤推定による最適化を行う。
はニューラルネットのパラメータなので、勾配法による最適化を図る。深層学習では最小化の目標を掲げて目的関数を設計し、勾配降下法による最適化を行うことが多い。その慣例に倣い、ここでは目的関数を負の対数尤度とし、その最小化を目指す。
が得られた時のEBMの負の対数尤度は以下。
次にこれの勾配を見てみる。
ここで、二項目については次のように変形できるため、
勾配はこのようにまとめられる。
この勾配に基づいてを更新すればよい。ただし二項目の期待値については解析的に得ることが困難なため、実際はモンテカルロ法で近似する。ランジュバン・モンテカルロ法を用いてモデルから個のデータをサンプリングし、期待値を近似する。
これを用いて改めて勾配を示す。
この勾配に基づいてを更新すればよい。
学習データとモデルから得られたサンプルでそれぞれエネルギーの平均を取り、それらの差分をlossとし、そこから逆伝播&パラメータ更新を行うことになる。学習データのエネルギーは低く(=確率密度は大きく)、モデルから得られたサンプルのエネルギーは高く(=確率密度は小さく)なるように学習を行うという意味になる。
🧐
与えられたデータが学習データとモデルのどちらから得られたものかをモデルに学習させ、それを欺くようなデータを取得したのちにそれを用いて再度モデルを更新する。EBMの学習はこれを繰り返すわけだが、実はこの学習方法はGAN(敵対的生成ネットワーク)と同じである。「識別モデルを欺くようなデータの生成」のやり方が異なるだけで、EBMは識別モデルの勾配を用いるが、GANでは生成器という別のモデルをいる。EBMは識別器だけを使ったGANと見てもいいかもしれない。
PyTorch実装
実際にモデルを学習させ、MNIST手書き数字を生成してみよう。
まずエネルギー関数を適当なニューラルネットで定義する。
class EnergyFunc(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Sequential(
nn.Linear(28 * 28, 256),
nn.ReLU(),
nn.Linear(256, 128),
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, 1)
)
def forward(self, x):
return self.net(x.view(len(x), -1))
CNNはなんか上手くいかなかったので全結合層のみで構築した。
次にLMC(ランジュバン・モンテカルロ)によるサンプリングを実装する。詳細はコメント参照。
def lmc(model, n_samples, K, alpha):
model.eval()
x = torch.rand(n_samples, 28, 28, requires_grad=True) # x0
for _ in range(K): # K回繰り返す
energy = model(x) # xのエネルギーを求める
energy.sum().backward() # 逆伝播でxまで勾配を届ける
grad = x.grad # 勾配を取得
eps = torch.randn_like(x) # ノイズ
x = x - alpha * grad + np.sqrt(2 * alpha) * eps # 更新
x = x.clamp(0, 1) # 0-1にクリップ
x = x.detach().requires_grad_(True) # 計算グラフから切断
model.zero_grad() # モデルパラメータに流れてきた勾配をリセット
return x
最後に学習コードを実装する。
def train(model, dataloader, optimizer, n_epochs, K, alpha):
prog.start(
n_iter=len(dataloader),
n_epochs=n_epochs,
label=["loss", "E_data", "E_model"]
)
for _ in range(n_epochs):
for (x_data, _) in dataloader:
x_model = lmc(model, len(x_data), K, alpha) # モデルからサンプルを得る
model.train()
energy_data = model(x_data).mean() # データ分布でのエネルギー
energy_model = model(x_model).mean() # モデルでのエネルギー
loss = energy_data - energy_model # 差分を取る
loss.backward() # 逆伝播
optimizer.step() # パラメータ更新
optimizer.zero_grad()
prog.update([loss.item(), energy_data.item(), energy_model.item()])
実際に学習させてみよう。
model = EnergyFunc()
optimizer = optim.Adam(model.parameters())
train(model, dataloader, optimizer, n_epochs=5, K=20, alpha=0.01)
1/5: ########## 100% [00:00:29.58] loss: -8.79, E_data: -67.02, E_model: -58.22
2/5: ########## 100% [00:00:29.50] loss: 0.45, E_data: -37.73, E_model: -38.18
3/5: ########## 100% [00:00:31.14] loss: 0.37, E_data: 7.21, E_model: 6.84
4/5: ########## 100% [00:00:45.66] loss: 0.43, E_data: -62.33, E_model: -62.76
5/5: ########## 100% [00:00:45.17] loss: 0.80, E_data: -254.31, E_model: -255.10
生成結果はこんな感じ。
Loading image...
数字を識別できないような画像も多く生成されているが、まあモデルサイズにしては頑張っている方なんですかね。知らんけど。
生成過程はこんな感じ。
Loading image...
一様分布のノイズが徐々に数字となっていく過程がわかる。
💡
このノイズが次第に消えていく過程はいわゆる拡散モデルと同じものである。EBMではモデルを微分して勾配を得るが、拡散モデルでは勾配そのものの推定を学習させ、推定した勾配を用いてデータを生成する。
EBMと識別モデル
ニューラルネットを用いた一般的な識別モデル(分類モデル)はEBMと見なせる。ここでの一般的とは以下を満たすという意味である。
- 出力層にクラス数と同じ数のニューロンを持つ
- softmax関数を通して確率分布を出力する
- 番目の出力が入力がクラスに属する確率を表す
- 交差エントロピー誤差の最小化を学習する
入力の下でのsoftmaxに通す前の番目の出力をとすると、識別モデルが出力する確率分布は次のように表せる。
これはエネルギー関数としたEBMである。
また入力を変数とするととの同時エネルギーが考えられ、
の条件付き分布が定義できる。
つまり一般的なニューラルネットベースの識別モデルは条件付きEBMとみなせる。そして正解ラベルがone-hotベクトルのとき、交差エントロピー誤差は予測分布の負の対数尤度と一致するため、その学習もEBMと同じもの(最尤推定)と見なせる。
EBMによる条件付き生成モデル
識別モデルはであるが、条件と変数を逆にしたは条件付き生成モデルである。
はとを受け取って何らかの実数を返す関数であれば何でもよく、例えば前章で扱った識別モデルと同じ構造のニューラルネットを用いて前章と同じようにとすることもできる。事前分布を推定して先の条件付き分布を周辺化するとノーマルな生成モデルが得られる。
これはニューラルネットベースの識別モデルが潜在的に生成モデルを持っていることを示している。
サンプリングも同じように行える。条件の下での同時エネルギーが大きくなる方向にを更新していけばよい。
学習についても考えてみる。識別モデルと同じ構造のニューラルネットをエネルギー関数に採用し、入力と条件(正解ラベル)のペアが個得られた下で負の対数尤度を最小化する。
条件のないEBMの最尤推定では次の勾配に基づいてパラメータを更新した。
ここに条件を付与したら良い。
は条件の下でのモデルからのサンプル。この勾配に基づいてモデルを学習させることで条件付き生成モデルを得る。
実際にやってみよう。生成する数字を指定できる手書き数字の生成モデルを作る。先ほど同様MNISTデータセットを用いる。
まずエネルギー関数を定義する。画像と条件を受け取ってその同時エネルギーを出力するモデルをニューラルネットで作成する。
class EnergyFunc(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Sequential(
nn.Linear(28 * 28, 256),
nn.ReLU(),
nn.Linear(256, 128),
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, 10)
)
def forward(self, x, y):
energy = self.net(x.view(len(x), -1))
energy = energy.gather(1, y.unsqueeze(1))
return energy
ネットワーク構造は一般的な識別モデルと同じ。ニューラルネットから出力された10個の値の中で与えた条件(数字)と対応する部分だけを出力する。
次にLMCサンプリングを実装する。これはモデルにも与えるようにするだけ。
def lmc(model, y, K, alpha):
model.eval()
x = torch.rand(len(y), 28, 28, requires_grad=True)
for _ in range(K):
energy = model(x, y) # yも与える
energy.sum().backward()
grad = x.grad
eps = torch.randn_like(x)
x = x - alpha * grad + np.sqrt(2 * alpha) * eps
x = x.clamp(0, 1)
x = x.detach().requires_grad_(True)
model.zero_grad()
return x
最後に学習コード。これも同じで、も与えるようにするだけ。
def train(model, dataloader, optimizer, n_epochs, K, alpha):
prog.start(
n_iter=len(dataloader),
n_epochs=n_epochs,
label=["loss", "E_data", "E_model"]
)
for _ in range(n_epochs):
for (x_data, y) in dataloader: # yも取得
x_model = lmc(model, y, K, alpha) # yも与える
model.train()
energy_data = model(x_data, y).mean() # yも与える
energy_model = model(x_model, y).mean() # yも与える
loss = energy_data - energy_model
loss.backward()
optimizer.step()
optimizer.zero_grad()
prog.update([loss.item(), energy_data.item(), energy_model.item()])
実際に生成してみよう。
model = EnergyFunc()
optimizer = optim.Adam(model.parameters())
train(model, dataloader, optimizer, n_epochs=5, K=20, alpha=0.01)
1/5: ########## 100% [00:00:30.41] loss: -14.70, E_data: -67.89, E_model: -53.20
2/5: ########## 100% [00:00:30.06] loss: -2.05, E_data: -197.98, E_model: -195.93
3/5: ########## 100% [00:00:30.22] loss: -1.79, E_data: -314.11, E_model: -312.32
4/5: ########## 100% [00:00:30.23] loss: -1.73, E_data: -562.98, E_model: -561.25
5/5: ########## 100% [00:00:29.96] loss: -4.45, E_data: -853.78, E_model: -849.33
生成結果はこんな感じ。条件は0~9を左から順に。
Loading image...
各数字の特徴は捉えられているんじゃないかな。生成過程も見てみよう。
Loading image...
数字によって収束の早さが異なるみたいだね。
オワリ
終わりです。拡散モデルを理解する基礎が身についたかな。
参考
- 岡野原大輔. 拡散モデル データ生成技術の数理. 岩波書店, 2023.
- 東京大学 松尾・岩澤研究室. "2024Summer 深層生成モデル :第5回エネルギーベースモデル." YouTube, 2024, https://www.youtube.com/watch?v=OUa338hdpcE.