icon

エネルギーベースモデルをノリで実装する

公開: 2025-06-15 / 最終更新: 2025-07-12
EBM生成モデルPyTorch

EBMの理論を簡単に学んでMNIST手書き数字を生成する。

EBM: エネルギーベースモデル

エネルギー関数fθ(x):RdRf_\theta(\bm x): \R^d \to \Rを用いて表される次の確率モデルをエネルギーベースモデル(EBM; Energy-Based Model)と呼ぶ。

pθ(x)=exp(fθ(x))ZθZθ=exp(fθ(x))dx\begin{align} p_\theta(\bm x) &= \frac{\exp(-f_\theta(\bm x))}{Z_\theta} \\ Z_\theta &= \int \exp(-f_\theta(\bm x)) \, d\bm x \end{align}

このモデル化をすることで、パラメータθ\thetaが記述する関数の制約が緩くなり、選択の自由度が上がる。本来であれば、規格化をZZにぶん投げるとしても分子は非負性を満たす必要があったが、exponentialによってその制約がなくなっている。

fθ(x)f_\theta(\bm x)はエネルギー関数と呼ばれ、入力x\bm xのエネルギーを表す。このモデルは統計力学から着想を得られたもので、エネルギーというのもそこの言葉だと思う。この確率モデルにおいては、エネルギーが低い地点で高い確率密度を持つということになる。マイナスが付いているのでfθ(x)f_\theta(\bm x)の大きさと確率密度の大きさが反転する。なおマイナスがついているのは統計力学に倣っているためで、本確率モデルにおける意味は何もない。ただただ関数fθf_\thetaの出力値が反転するだけ。

fθf_\thetaの制約が緩いため、例えばニューラルネットような表現力の高い関数もモデルに組み込める。本稿でも、パラメータθ\thetaを持つニューラルネットをエネルギー関数fθf_\thetaとして採用し、その最適化(学習)を行う。

ランジュバン・モンテカルロ法

生成モデルとして確率モデルを考える場合、そこからのサンプリングが可能でないといけない。

任意の微分可能な確率分布p(x)p(x)からのサンプルを得る方法としてランジュバン・モンテカルロ法が存在する。MCMCの一種。適当な分布から初期値x0x_0をサンプリングし、次の更新式に従ってxxを更新する。

xt:=xt1+αxlnp(xt1)+2αϵϵN(0,I)\begin{align} x_t &:= x_{t-1} + \alpha\nabla_x\ln p(x_{t-1}) + \sqrt{2\alpha}\epsilon \\ \epsilon &\sim \mathcal N(0, I) \end{align}

確率密度の対数の勾配=確率密度を大きくする方向にステップ幅α\alphaだけxxを更新し、α\alphaに応じたノイズを足す。これをKK回繰り返したxKx_Kをサンプリング結果とする。ノイズが無ければただの勾配上昇法だが、その場合得られるサンプルが確率密度関数の極値に限定されるため、ノイズを加えて極値以外にも辿り着けるようにしている。 α0\alpha\to0KK\to\inftyの極限でxKx_Kp(x)p(x)からのサンプルに収束する事が知られている。現実的には、α\alphaを十分小さく、KKを十分大きくすることでxKx_Kp(x)p(x)からのサンプルと見なす。

本稿で扱うニューラルネットベースのEBMにおいてもこのサンプリング手法を用いる。勾配は次のように求められる。

xlnp(x)=xlnexp(fθ(x))Zθ=xfθ(x)\begin{align} \nabla_x\ln p(\bm x) &= \nabla_x \ln \frac{\exp(-f_\theta(\bm x))}{Z_\theta} \\ &= -\nabla_x f_\theta(\bm x) \end{align}

ニューラルネットにその時刻のサンプルを突っ込んで出てきた値を微分したらいいだけ。EBMのサンプリングでは以下の更新式に基づいてx\bm xを更新する。

xt:=xt1αxfθ(xt1)+2αϵ\begin{align} \bm x_t &:= \bm x_{t-1} - \alpha\nabla_{\bm x} f_\theta(\bm x_{t-1}) + \sqrt{2\alpha}\epsilon \\ \end{align}

EBMの最尤推定

EBMのパラメータθ\thetaの最適化を確認する。最尤推定による最適化を行う。

θ\thetaはニューラルネットのパラメータなので、勾配法による最適化を図る。深層学習では最小化の目標を掲げて目的関数を設計し、勾配降下法による最適化を行うことが多い。その慣例に倣い、ここでは目的関数を負の対数尤度とし、その最小化を目指す。

X={x(1),x(2),,x(N)}X = \{\bm x^{(1)}, \bm x^{(2)}, \ldots, \bm x^{(N)}\}が得られた時のEBMの負の対数尤度は以下。

lnpθ(X)=1Nn=1Nlnpθ(x(n))=1Nn=1Nlnexp(fθ(x(n)))Zθ=1Nn=1Nfθ(x(n))+lnZθ\begin{align} -\ln p_\theta(X) &= -\frac{1}{N}\sum_{n=1}^N \ln p_\theta(\bm x^{(n)}) \\ &= -\frac{1}{N}\sum_{n=1}^N \ln \frac{\exp(-f_\theta(\bm x^{(n)}))}{Z_\theta} \\ &= \frac{1}{N}\sum_{n=1}^N f_\theta(\bm x^{(n)}) + \ln Z_\theta \end{align}

次にこれの勾配を見てみる。

θlnpθ(X)=θ1Nn=1Nfθ(x(n))+θlnZθ\begin{align} -\frac{\partial}{\partial\theta} \ln p_\theta(X) = \frac{\partial}{\partial\theta}\frac{1}{N}\sum_{n=1}^N f_\theta(\bm x^{(n)}) + \frac{\partial}{\partial\theta}\ln Z_\theta \end{align}

ここで、二項目については次のように変形できるため、

θlnZθ=1ZθθZθ=1Zθθexp(fθ(x))dx=1Zθθfθ(x)exp(fθ(x))dx=θfθ(x)exp(fθ(x))Zθdx=θfθ(x)pθ(x)dx=θEpθ(x)[fθ(x)]\begin{align} \frac{\partial}{\partial\theta} \ln Z_\theta &= \frac{1}{Z_\theta} \frac{\partial}{\partial\theta}Z_\theta \\ &= \frac{1}{Z_\theta} \int \frac{\partial}{\partial\theta} \exp(-f_\theta(\bm x)) \, d\bm x \\ &= -\frac{1}{Z_\theta} \int \frac{\partial}{\partial\theta} f_\theta(\bm x) \exp(-f_\theta(\bm x)) \, d\bm x \\ &= -\frac{\partial}{\partial\theta} \int f_\theta(\bm x) \frac{\exp(-f_\theta(\bm x))}{Z_\theta} \, d\bm x \\ &= -\frac{\partial}{\partial\theta} \int f_\theta(\bm x) p_\theta(\bm x) \, d\bm x \\ &= -\frac{\partial}{\partial\theta} \mathbb E_{p_\theta(\bm x)} [f_\theta(\bm x)] \end{align}

勾配はこのようにまとめられる。

θlnpθ(X)=θ1Nn=1Nfθ(x(n))θEpθ(x)[fθ(x)]\begin{align} -\frac{\partial}{\partial\theta} \ln p_\theta(X) = \frac{\partial}{\partial\theta}\frac{1}{N}\sum_{n=1}^N f_\theta(\bm x^{(n)}) - \frac{\partial}{\partial\theta} \mathbb E_{p_\theta(\bm x)} [f_\theta(\bm x)] \end{align}

この勾配に基づいてθ\thetaを更新すればよい。ただし二項目の期待値については解析的に得ることが困難なため、実際はモンテカルロ法で近似する。ランジュバン・モンテカルロ法を用いてモデルpθp_\thetaからNN個のデータx(1),x(2),,x(N)\bm x'^{(1)},\bm x'^{(2)},\ldots,\bm x'^{(N)}をサンプリングし、期待値を近似する。

Epθ(x)[fθ(x)]1Nn=1Nfθ(x(n))\begin{align} \mathbb E_{p_\theta(\bm x)} [f_\theta(\bm x)] \approx \frac{1}{N}\sum_{n=1}^N f_\theta(\bm x'^{(n)}) \end{align}

これを用いて改めて勾配を示す。

θlnpθ(X)θ1Nn=1Nfθ(x(n))θ1Nn=1Nfθ(x(n))\begin{align} -\frac{\partial}{\partial\theta} \ln p_\theta(X) \approx \frac{\partial}{\partial\theta}\frac{1}{N}\sum_{n=1}^N f_\theta(\bm x^{(n)}) - \frac{\partial}{\partial\theta} \frac{1}{N}\sum_{n=1}^N f_\theta(\bm x'^{(n)}) \end{align}

この勾配に基づいてθ\thetaを更新すればよい。

学習データとモデルから得られたサンプルでそれぞれエネルギーの平均を取り、それらの差分をlossとし、そこから逆伝播&パラメータ更新を行うことになる。学習データx(n)\bm x^{(n)}のエネルギーは低く(=確率密度は大きく)、モデルから得られたサンプルx(n)\bm x'^{(n)}のエネルギーは高く(=確率密度は小さく)なるように学習を行うという意味になる。

🧐

与えられたデータが学習データとモデルのどちらから得られたものかをモデルに学習させ、それを欺くようなデータを取得したのちにそれを用いて再度モデルを更新する。EBMの学習はこれを繰り返すわけだが、実はこの学習方法はGAN(敵対的生成ネットワーク)と同じである。「識別モデルを欺くようなデータの生成」のやり方が異なるだけで、EBMは識別モデルの勾配を用いるが、GANでは生成器という別のモデルをいる。EBMは識別器だけを使ったGANと見てもいいかもしれない。

PyTorch実装

実際にモデルを学習させ、MNIST手書き数字を生成してみよう。

まずエネルギー関数fθf_\thetaを適当なニューラルネットで定義する。

class EnergyFunc(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(28 * 28, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
 
    def forward(self, x):
        return self.net(x.view(len(x), -1))

CNNはなんか上手くいかなかったので全結合層のみで構築した。

次にLMC(ランジュバン・モンテカルロ)によるサンプリングを実装する。詳細はコメント参照。

def lmc(model, n_samples, K, alpha):
    model.eval()
    x = torch.rand(n_samples, 28, 28, requires_grad=True) # x0
    for _ in range(K): # K回繰り返す
        energy = model(x) # xのエネルギーを求める
        energy.sum().backward() # 逆伝播でxまで勾配を届ける
        grad = x.grad # 勾配を取得
        eps = torch.randn_like(x) # ノイズ
        x = x - alpha * grad + np.sqrt(2 * alpha) * eps # 更新
        x = x.clamp(0, 1) # 0-1にクリップ
        x = x.detach().requires_grad_(True) # 計算グラフから切断
    model.zero_grad() # モデルパラメータに流れてきた勾配をリセット
    return x

最後に学習コードを実装する。

def train(model, dataloader, optimizer, n_epochs, K, alpha):
    prog.start(
        n_iter=len(dataloader),
        n_epochs=n_epochs,
        label=["loss", "E_data", "E_model"]
    )
    for _ in range(n_epochs):
        for (x_data, _) in dataloader:
            x_model = lmc(model, len(x_data), K, alpha) # モデルからサンプルを得る
            model.train()
            energy_data = model(x_data).mean() # データ分布でのエネルギー
            energy_model = model(x_model).mean() # モデルでのエネルギー
            loss = energy_data - energy_model # 差分を取る
            loss.backward() # 逆伝播
            optimizer.step() # パラメータ更新
            optimizer.zero_grad()
            prog.update([loss.item(), energy_data.item(), energy_model.item()])

実際に学習させてみよう。

model = EnergyFunc()
optimizer = optim.Adam(model.parameters())
train(model, dataloader, optimizer, n_epochs=5, K=20, alpha=0.01)
1/5: ########## 100% [00:00:29.58] loss: -8.79, E_data: -67.02, E_model: -58.22  
2/5: ########## 100% [00:00:29.50] loss: 0.45, E_data: -37.73, E_model: -38.18  
3/5: ########## 100% [00:00:31.14] loss: 0.37, E_data: 7.21, E_model: 6.84       
4/5: ########## 100% [00:00:45.66] loss: 0.43, E_data: -62.33, E_model: -62.76 
5/5: ########## 100% [00:00:45.17] loss: 0.80, E_data: -254.31, E_model: -255.10

生成結果はこんな感じ。

Loading image...

数字を識別できないような画像も多く生成されているが、まあモデルサイズにしては頑張っている方なんですかね。知らんけど。

生成過程はこんな感じ。

Loading image...

一様分布のノイズが徐々に数字となっていく過程がわかる。

💡

このノイズが次第に消えていく過程はいわゆる拡散モデルと同じものである。EBMではモデルを微分して勾配を得るが、拡散モデルでは勾配そのものの推定を学習させ、推定した勾配を用いてデータを生成する。

EBMと識別モデル

ニューラルネットを用いた一般的な識別モデル(分類モデル)はEBMと見なせる。ここでの一般的とは以下を満たすという意味である。

  • 出力層にクラス数と同じ数のニューロンを持つ
  • softmax関数を通して確率分布を出力する
  • yy番目の出力が入力x\bm xがクラスyyに属する確率を表す
  • 交差エントロピー誤差の最小化を学習する

入力x\bm xの下でのsoftmaxに通す前のyy番目の出力をFθ(x)yF_\theta(\bm x)_yとすると、識別モデルが出力する確率分布は次のように表せる。

pθ(y)=exp(Fθ(x)y)yexp(Fθ(x)y)\begin{align} p_\theta(y) = \frac{\exp(F_\theta(\bm x)_y)}{\sum_{y'} \exp(F_\theta(\bm x)_{y'})} \end{align}

これはエネルギー関数fθ(y)=Fθ(x)yf_\theta(y) = -F_\theta(\bm x)_yとしたEBMである。

また入力x\bm xを変数とするとx\bm xyyの同時エネルギーが考えられ、

fθ(x,y)=Fθ(x)y\begin{align} f_\theta(\bm x, y) = -F_\theta(\bm x)_y \end{align}

yyの条件付き分布が定義できる。

pθ(yx)=exp(fθ(x,y))yexp(fθ(x,y))=exp(Fθ(x)y)yexp(Fθ(x)y)\begin{align} p_\theta(y|\bm x) &= \frac{\exp(-f_\theta(\bm x, y))}{\sum_{y'} \exp(-f_\theta(\bm x, y'))} \\ &= \frac{\exp(F_\theta(\bm x)_y)}{\sum_{y'} \exp(F_\theta(\bm x)_{y'})} \end{align}

つまり一般的なニューラルネットベースの識別モデルは条件付きEBMとみなせる。そして正解ラベルがone-hotベクトルのとき、交差エントロピー誤差は予測分布の負の対数尤度と一致するため、その学習もEBMと同じもの(最尤推定)と見なせる。

EBMによる条件付き生成モデル

識別モデルはpθ(yx)p_\theta(y|\bm x)であるが、条件と変数を逆にしたpθ(xy)p_\theta(\bm x|y)は条件付き生成モデルである。

pθ(xy)=exp(fθ(x,y))xexp(fθ(x,y))\begin{align} p_\theta(\bm x|y) &= \frac{\exp(-f_\theta(\bm x, y))}{\sum_{\bm x'} \exp(-f_\theta(\bm x', y))} \\ \end{align}

fθ(x,y)f_\theta(\bm x, y)x\bm xyyを受け取って何らかの実数を返す関数であれば何でもよく、例えば前章で扱った識別モデルと同じ構造のニューラルネットを用いて前章と同じようにfθ(x,y)=Fθ(x)yf_\theta(\bm x, y) = -F_\theta(\bm x)_yとすることもできる。事前分布を推定して先の条件付き分布を周辺化するとノーマルな生成モデルが得られる。

ypθ(xy)p(y)=pθ(x)\begin{align} \sum_y p_\theta(\bm x|y)p(y) = p_\theta(\bm x) \end{align}

これはニューラルネットベースの識別モデルが潜在的に生成モデルを持っていることを示している。

サンプリングも同じように行える。条件yyの下での同時エネルギーfθ(x,y)f_\theta(\bm x, y)が大きくなる方向にx\bm xを更新していけばよい。

xt:=xt1αxfθ(xt1,y)+2αϵ\begin{align} \bm x_t &:= \bm x_{t-1} - \alpha\nabla_{\bm x} f_\theta(\bm x_{t-1}, y) + \sqrt{2\alpha}\epsilon \\ \end{align}

学習についても考えてみる。識別モデルと同じ構造のニューラルネットをエネルギー関数に採用し、入力と条件(正解ラベル)のペア(x(n),y(n))(\bm x^{(n)}, y^{(n)})NN個得られた下で負の対数尤度を最小化する。

条件のないEBMの最尤推定では次の勾配に基づいてパラメータを更新した。

θlnpθ(X)θ1Nn=1Nfθ(x(n))θ1Nn=1Nfθ(x(n))\begin{align} -\frac{\partial}{\partial\theta} \ln p_\theta(X) \approx \frac{\partial}{\partial\theta}\frac{1}{N}\sum_{n=1}^N f_\theta(\bm x^{(n)}) - \frac{\partial}{\partial\theta} \frac{1}{N}\sum_{n=1}^N f_\theta(\bm x'^{(n)}) \end{align}

ここに条件y(n)y^{(n)}を付与したら良い。

θlnpθ(XY)θ1Nn=1Nfθ(x(n),y(n))θ1Nn=1Nfθ(x(n),y(n))\begin{align} -\frac{\partial}{\partial\theta} \ln p_\theta(X|Y) \approx \frac{\partial}{\partial\theta}\frac{1}{N}\sum_{n=1}^N f_\theta(\bm x^{(n)}, y^{(n)}) - \frac{\partial}{\partial\theta} \frac{1}{N}\sum_{n=1}^N f_\theta(\bm x'^{(n)}, y^{(n)}) \end{align}

x(n)\bm x’^{(n)}は条件y(n)y^{(n)}の下でのモデルからのサンプル。この勾配に基づいてモデルを学習させることで条件付き生成モデルを得る。

実際にやってみよう。生成する数字を指定できる手書き数字の生成モデルを作る。先ほど同様MNISTデータセットを用いる。

まずエネルギー関数を定義する。画像xxと条件yyを受け取ってその同時エネルギーを出力するモデルをニューラルネットで作成する。

class EnergyFunc(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(28 * 28, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )
 
    def forward(self, x, y):
        energy = self.net(x.view(len(x), -1))
        energy = energy.gather(1, y.unsqueeze(1))
        return energy

ネットワーク構造は一般的な識別モデルと同じ。ニューラルネットから出力された10個の値の中で与えた条件(数字)と対応する部分だけを出力する。

次にLMCサンプリングを実装する。これはモデルにyyも与えるようにするだけ。

def lmc(model, y, K, alpha):
    model.eval()
    x = torch.rand(len(y), 28, 28, requires_grad=True)
    for _ in range(K):
        energy = model(x, y) # yも与える
        energy.sum().backward()
        grad = x.grad
        eps = torch.randn_like(x)
        x = x - alpha * grad + np.sqrt(2 * alpha) * eps
        x = x.clamp(0, 1)
        x = x.detach().requires_grad_(True)
    model.zero_grad()
    return x

最後に学習コード。これも同じで、yyも与えるようにするだけ。

def train(model, dataloader, optimizer, n_epochs, K, alpha):
    prog.start(
        n_iter=len(dataloader),
        n_epochs=n_epochs,
        label=["loss", "E_data", "E_model"]
    )
    for _ in range(n_epochs):
        for (x_data, y) in dataloader: # yも取得
            x_model = lmc(model, y, K, alpha) # yも与える
            model.train()
            energy_data = model(x_data, y).mean() # yも与える
            energy_model = model(x_model, y).mean() # yも与える
            loss = energy_data - energy_model
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            prog.update([loss.item(), energy_data.item(), energy_model.item()])

実際に生成してみよう。

model = EnergyFunc()
optimizer = optim.Adam(model.parameters())
train(model, dataloader, optimizer, n_epochs=5, K=20, alpha=0.01)
1/5: ########## 100% [00:00:30.41] loss: -14.70, E_data: -67.89, E_model: -53.20  
2/5: ########## 100% [00:00:30.06] loss: -2.05, E_data: -197.98, E_model: -195.93  
3/5: ########## 100% [00:00:30.22] loss: -1.79, E_data: -314.11, E_model: -312.32  
4/5: ########## 100% [00:00:30.23] loss: -1.73, E_data: -562.98, E_model: -561.25  
5/5: ########## 100% [00:00:29.96] loss: -4.45, E_data: -853.78, E_model: -849.33   

生成結果はこんな感じ。条件は0~9を左から順に。

Loading image...

各数字の特徴は捉えられているんじゃないかな。生成過程も見てみよう。

Loading image...

数字によって収束の早さが異なるみたいだね。

オワリ

終わりです。拡散モデルを理解する基礎が身についたかな。

参考

  1. 岡野原大輔. 拡散モデル データ生成技術の数理. 岩波書店, 2023.
  2. 東京大学 松尾・岩澤研究室. "2024Summer 深層生成モデル :第5回エネルギーベースモデル." YouTube, 2024, https://www.youtube.com/watch?v=OUa338hdpcE.