icon

正規化層の動きを理解する

公開: 2024-03-01 / 最終更新: 2025-05-24
PyTorch深層学習正規化層

正規化層(バッチ正規化、層正規化)がどう動いているのかを理解し、PyTorchで実装する。

正規化層

Normalization Layers

深層学習における、入力データを正規化して出力する層。色々な種類があり、正規化する軸や細かい挙動が異なる。

正規化層をニューラルネットワークに取り入れることで嬉しいことがいっぱい起こる。学習が安定したり速くなったりするらしい。ただそこら辺の話はここではしない。ここでは、それぞれの正規化層がどのように動いているかをまとめる。またそれをPyTorchで実装し、挙動を確認する。

⚠️

正規化とはデータのスケールを揃える操作のことで、平均・分散を揃える操作や最大・最小値を揃える操作などがある。平均・分散を揃える操作は標準化と呼ばれることもある。ただ本稿においては、平均を0、分散を1に揃える操作を正規化とし、それ以外の意味を持たないものとする。

バッチ正規化

Batch Normalization

バッチ内を正規化し、学習したパラメータに従ってスケーリング・シフトして出力する層[?]。正規化は特徴量ごとに行う。バッチサイズmmのミニバッチB={x(1),x(2),,x(m)}\mathcal B = \{\bm x^{(1)},\bm x^{(2)},\cdots,\bm x^{(m)} \}が得られた時、以下の演算で入力x(n)\bm x^{(n)}に対する出力値y(n)\bm y^{(n)}を決定する。

μB=1mi=1mx(i)σB2=1mi=1m(x(i)μB)2x^(n)=x(n)μBσB2+ϵy(n)=γx^(n)+β\begin{align} \bm\mu_{\mathcal B} &= \frac{1}{m}\sum_{i=1}^m \bm x^{(i)} \\ \bm\sigma^2_{\mathcal B} &= \frac{1}{m}\sum_{i=1}^m (\bm x^{(i)} - \bm\mu_{\mathcal B})^2 \\ \hat{\bm x}^{(n)} &= \frac{\bm x^{(n)} - \bm\mu_{\mathcal B}}{\sqrt{\bm\sigma^2_{\mathcal B} + \epsilon}} \\ \bm y^{(n)} &= \bm\gamma\odot\hat{\bm x}^{(n)} + \bm\beta \end{align}
  • x(n),x^(n),y(n),μB,σB2,γ,βRd\bm x^{(n)}, \, \hat{\bm x}^{(n)}, \, \bm y^{(n)}, \, \bm\mu_{\mathcal B}, \, \bm\sigma^2_{\mathcal B}, \, \bm\gamma, \, \bm\beta \in \R^d
  • dd: 特徴量の数
  • ϵ\epsilon: 微小値(0除算回避用)

まずミニバッチB\mathcal Bの平均μB\bm\mu_{\mathcal B}と分散σB2\bm\sigma^2_{\mathcal B}を求める。次にそれらを用いてx(n)\bm x^{(n)}を正規化する。最後にγ,β\bm\gamma,\bm\betaを用いてスケーリングとシフトを行う。γ,β\bm\gamma,\bm\betaは学習可能なパラメータで、出力データy(n)\bm y^{(n)}の分散と平均を意味する。まとめると、この層は、分布(分散γ\bm\gamma、平均β\bm\beta)を学習し、その分布に従うように入力データを変換する層ということ。


さて、上記の演算は学習時に行うもので、推論時には使えない。ミニバッチ内の他のデータに依って出力が変わってしまうため。推論時はy(n)\bm y^{(n)}x(n)\bm x^{(n)}のみに依存している必要がある。また推論時はバッチサイズが1であることも多く、その場合x^(n)\hat{\bm x}^{(n)}0\bm 0になるためy(n)\bm y^{(n)}β\bm\betaに固定されてしまう。

推論時はx^(n)\hat{\bm x}^{(n)}が以下に変わる。

x^(n)=x(n)E[x]Var[x]+ϵ\begin{align} \hat{\bm x}^{(n)} = \frac{\bm x^{(n)} - \mathbb E[\bm x]}{\sqrt{\text{Var}[\bm x] + \epsilon}} \end{align}

入力データx\bm xの平均E[x]\mathbb E[\bm x]と分散Var[x]\text{Var}[\bm x]を用いる。これらは学習時に観測したデータから求める。つまり実質的に学習データ全体の平均と分散となる

CNNでのバッチ正規化

画像や畳み込み層からの出力は3次元のデータで表される。当然これらも同じように正規化することが可能である。サンプルの形状が(c, h, w)の場合、c×h×wc\times h\times w個ずつパラメータ(平均、分散)を用意し、特徴量ごとに正規化(&スケーリング・シフト)するということ。

ただこの方法は基本的に使わず、実際はチャンネルごとに正規化する。サンプルを跨いだ同じ特徴マップの値を全て同じ種類の特徴量と見做し、特徴量ごとに正規化する。パラメータは平均と分散がチャンネルの数だけ必要になる。

確かに、各ピクセルの値を独立した別々の特徴量と見做すのは違和感があるので、チャンネルごとに正規化するのは自然に感じる。論文[?]には

For convolutional layers, we additionally want the normalization to obey the convolutional property (DeepL訳: 畳み込みレイヤーの場合、さらに正規化は畳み込みの性質に従うようにしたい)

と書いてあった。

実装

PyTorchで実装してみよう。

import torch
import torch.nn as nn

まず学習時の挙動を確認する。適当なミニバッチを用意する。

batch_size = 3
num_features = 4
x = torch.arange(num_features * batch_size)
x = x.reshape(batch_size, num_features).to(torch.float32)
x
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])

パラメータも初期化しておく。

gamma = torch.ones(num_features)
beta = torch.zeros(num_features)
eps = 1e-5

平均は全て0、分散は全て1で初期化した。

では正規化層の演算を実装する。まずミニバッチの統計量を求める。

mean = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)
mean, var
(tensor([4., 5., 6., 7.]), tensor([10.6667, 10.6667, 10.6667, 10.6667]))

dim=0でバッチ軸を指定し、特徴量ごとの平均と分散を求めた。

次は正規化。

x_hat = (x - mean) / torch.sqrt(var + eps)
x_hat
tensor([[-1.0000, -1.0000, -1.0000, -1.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 1.0000,  1.0000,  1.0000,  1.0000]])

平均0、分散1になった。

最後にパラメータでスケーリングとシフトを行う。

y = gamma * x_hat + beta
y
tensor([[-1.0000, -1.0000, -1.0000, -1.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 1.0000,  1.0000,  1.0000,  1.0000]])

今はパラメータが平均0、分散1なので変化なし。


以上がバッチ正規化の(学習時の)演算である。これをnn.Moduleとして実装してみよう。

class BatchNormalization(nn.Module):
    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
 
    def forward(self, x):
        mean = x.mean(dim=0)
        var = x.var(dim=0)
        x_hat = (x - mean) / (torch.sqrt(var) + self.eps)
        y = x_hat * self.gamma + self.beta
        return y

このように使う。

norm = BatchNormalization(num_features)
y = norm(x)
y
tensor([[-1.0000, -1.0000, -1.0000, -1.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 1.0000,  1.0000,  1.0000,  1.0000]], grad_fn=<AddBackward0>)

これで、簡易バッチ正規化層の完成。


さて、ちゃんとしたバッチ正規化層も作ってみよう。推論時の挙動を追加する。

class BatchNormalization(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))
 
    def forward(self, x):
        if self.training: # 学習
            mean = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)
            var_unbiased = x.var(dim=0, unbiased=True)
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var_unbiased
        else: # 推論
            mean = self.running_mean
            var = self.running_var
        x_hat = (x - mean) / (torch.sqrt(var) + self.eps)
        y = x_hat * self.gamma + self.beta
        return y

E[x]\mathbb E[\bm x]Var[x]\text{Var}[\bm x]running_meanrunning_varとして保持し、推論時に使う。そしてそれらは学習時に都度更新する。移動平均によって動的に求めている。またVar[x]\text{Var}[\bm x]は不偏分散なのでunbiased=Trueにする。

学習モードで適当にデータを見せるとrunning_meanrunning_varが更新される。

norm = BatchNormalization(num_features)
norm.state_dict() # 初期値
OrderedDict([('gamma', tensor([1., 1., 1., 1.])),
             ('beta', tensor([0., 0., 0., 0.])),
             ('running_mean', tensor([0., 0., 0., 0.])),
             ('running_var', tensor([1., 1., 1., 1.]))])

norm.train()
torch.manual_seed(0)
for _ in range(100):
    x = torch.randn(batch_size, num_features)
    norm(x)
norm.state_dict()
OrderedDict([('gamma', tensor([1., 1., 1., 1.])),
             ('beta', tensor([0., 0., 0., 0.])),
             ('running_mean', tensor([ 0.0527, -0.0417, -0.0876, -0.0537])),
             ('running_var', tensor([0.9879, 0.4089, 1.3890, 0.7146]))])

推論モードにするとrunning_meanrunning_varが演算に使われる。

norm.eval()
x = torch.randn(1, num_features)
y = norm(x)
y
tensor([[ 0.8423, -1.5081,  1.0178,  2.3462]], grad_fn=<AddBackward0>)

こういうこと。

gamma, beta, running_mean, running_var = norm.state_dict().values()
eps = norm.eps
y = (x - running_mean) / torch.sqrt(running_var + eps) * gamma + beta
y
tensor([[ 0.8423, -1.5081,  1.0178,  2.3462]])

先ほど実装したBatchNormalizationnn.BatchNorm1dとしてそのままPyTorchに実装されている。

挙動もほぼ同じ。

norm = nn.BatchNorm1d(num_features)
norm.state_dict()
OrderedDict([('weight', tensor([1., 1., 1., 1.])),
             ('bias', tensor([0., 0., 0., 0.])),
             ('running_mean', tensor([0., 0., 0., 0.])),
             ('running_var', tensor([1., 1., 1., 1.])),
             ('num_batches_tracked', tensor(0))])

norm.train()
torch.manual_seed(0)
for _ in range(100):
    x = torch.randn(batch_size, num_features)
    norm(x)
norm.state_dict()
OrderedDict([('weight', tensor([1., 1., 1., 1.])),
             ('bias', tensor([0., 0., 0., 0.])),
             ('running_mean', tensor([ 0.0527, -0.0417, -0.0876, -0.0537])),
             ('running_var', tensor([0.9879, 0.4089, 1.3890, 0.7146])),
             ('num_batches_tracked', tensor(100))])

norm.eval()
x = torch.randn(1, num_features)
y = norm(x)
y
tensor([[ 0.8423, -1.5081,  1.0178,  2.3462]],
       grad_fn=<NativeBatchNormBackward0>)

同じ値になったね。

ちなみにPyTorchのバッチ正規化層は学習モードでバッチサイズ1のデータが与えられるとエラーを吐く。

norm.train()
try:
    norm(x)
except Exception as e:
    print(e)
Expected more than 1 value per channel when training, got input size torch.Size([1, 4])

nn.BatchNorm1dはチャンネル軸が加わった3次元のデータに対しても使用できる。チャンネルごとに正規化を行う。

⚠️

ここでのチャンネル軸というのは、サンプルを表した2階以上のテンソルの1番外側の軸のこと。CNNで画像を扱うときによく見る。画像以外で使われている場面はあまり見ないが、とりあえず公式ドキュメントでそう呼ばれていたので倣った。

チャンネル軸が加わった3次元のデータと表現したが、時系列データという解釈もできて、実際に公式ドキュメントでは3つ目の軸をsequence lengthと呼んでいる。

batch_size = 2
c = 3
num_features = 4
x = torch.arange(batch_size * c * num_features)
x = x.reshape(batch_size, c, num_features).to(torch.float32)
x
tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],
 
        [[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]])

norm = nn.BatchNorm1d(c)
y = norm(x)
y
tensor([[[-1.2288, -1.0650, -0.9012, -0.7373],
         [-1.2288, -1.0650, -0.9012, -0.7373],
         [-1.2288, -1.0650, -0.9012, -0.7373]],
 
        [[ 0.7373,  0.9012,  1.0650,  1.2288],
         [ 0.7373,  0.9012,  1.0650,  1.2288],
         [ 0.7373,  0.9012,  1.0650,  1.2288]]],
       grad_fn=<NativeBatchNormBackward0>)

こういうこと。

mean = x.mean(dim=(0, 2), keepdim=True)
var = x.var(dim=(0, 2), unbiased=False, keepdim=True)
gamma = norm.weight.reshape(1, c, 1)
beta = norm.bias.reshape(1, c, 1)
mean, var
(tensor([[[ 7.5000],
          [11.5000],
          [15.5000]]]),
 tensor([[[37.2500],
          [37.2500],
          [37.2500]]]))

y = (x - mean) / torch.sqrt(var + eps) * gamma + beta
y
tensor([[[-1.2288, -1.0650, -0.9012, -0.7373],
         [-1.2288, -1.0650, -0.9012, -0.7373],
         [-1.2288, -1.0650, -0.9012, -0.7373]],
 
        [[ 0.7373,  0.9012,  1.0650,  1.2288],
         [ 0.7373,  0.9012,  1.0650,  1.2288],
         [ 0.7373,  0.9012,  1.0650,  1.2288]]], grad_fn=<AddBackward0>)

各チャンネルが2次元の場合はnn.BatchNorm2dを使う。また3次元の場合はnn.BatchNorm3dを使う。4次元以上はない。

nn.BatchNorm2dはCNNでよく使う。先で説明したCNNの場合の動作と同じ動きをする。

batch_size = 32
c, w, h = 3, 224, 224
x = torch.randn(batch_size, c, w, h)
x.shape
torch.Size([32, 3, 224, 224])

norm = nn.BatchNorm2d(c)
norm.state_dict()
OrderedDict([('weight', tensor([1., 1., 1.])),
             ('bias', tensor([0., 0., 0.])),
             ('running_mean', tensor([0., 0., 0.])),
             ('running_var', tensor([1., 1., 1.])),
             ('num_batches_tracked', tensor(0))])

パラメータの数はチャンネルの数と一緒。

y = norm(x)
y.shape
torch.Size([32, 3, 224, 224])

層正規化

Layer Normalization

サンプルごとに正規化し、特徴量ごとにスケーリング・シフトして出力する層[?]。バッチ正規化から正規化の方法が変わっただけ(スケーリング・シフトは一緒)。RNNにバッチ正規化が適用しづらいということで提案された。

層正規化はバッチ正規化同様、特徴量の数だけパラメータを持つ。また、この層は演算結果がバッチ内の他のデータに依らないため、学習時と推論時で挙動が変わらない(変える必要がない)。そのため、推論用の統計量を保持する必要がない。

μ(n)=1di=1dxi(n)σ2(n)=1di=1d(xi(n)μ(n))2x^(n)=x(n)μ(n)σ2(n)+ϵy(n)=γx^(n)+β\begin{align} \mu^{(n)} &= \frac{1}{d}\sum_{i=1}^d x_i^{(n)} \\ {\sigma^2}^{(n)} &= \frac{1}{d}\sum_{i=1}^d (x_i^{(n)} - \mu^{(n)})^2 \\ \hat{\bm x}^{(n)} &= \frac{\bm x^{(n)} - \mu^{(n)}}{\sqrt{{\sigma^2}^{(n)} + \epsilon}} \\ \bm y^{(n)} &= \bm\gamma\odot\hat{\bm x}^{(n)} + \bm\beta \end{align}
  • x(n),x^(n),y(n),γ,βRd\bm x^{(n)}, \, \hat{\bm x}^{(n)}, \, \bm y^{(n)}, \, \bm\gamma, \, \bm\beta \in \R^d
  • μ(n),σ2(n)R\mu^{(n)}, \, {\sigma^2}^{(n)} \in \R
  • x(n)=(x1(n),x2(n),,xd(n))\bm x^{(n)} = (x_1^{(n)}, x_2^{(n)}, \cdots, x_d^{(n)})^\top

実装

PyTorchで実装して挙動を確認してみよう。

適当なミニバッチを用意する。

batch_size = 4
d = 4
x = torch.arange(batch_size * d)
x = x.reshape(batch_size, d).to(torch.float32)
x
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]])

パラメータの初期化。

gamma = torch.ones(d)
beta = torch.zeros(d)
eps = 1e-5

サンプルごとの統計量を求める。

mean = x.mean(dim=1, keepdim=True)
var = x.var(dim=1, unbiased=False, keepdim=True)
mean, var
(tensor([[ 1.5000],
         [ 5.5000],
         [ 9.5000],
         [13.5000]]),
 tensor([[1.2500],
         [1.2500],
         [1.2500],
         [1.2500]]))

dim=1とした(バッチ正規化では0)。

次に正規化。

x_hat = (x - mean) / torch.sqrt(var + eps)
x_hat
tensor([[-1.3416, -0.4472,  0.4472,  1.3416],
        [-1.3416, -0.4472,  0.4472,  1.3416],
        [-1.3416, -0.4472,  0.4472,  1.3416],
        [-1.3416, -0.4472,  0.4472,  1.3416]])

サンプルごとに正規化された。

最後にスケーリングとシフトを行う。

y = gamma * x_hat + beta
y
tensor([[-1.3416, -0.4472,  0.4472,  1.3416],
        [-1.3416, -0.4472,  0.4472,  1.3416],
        [-1.3416, -0.4472,  0.4472,  1.3416],
        [-1.3416, -0.4472,  0.4472,  1.3416]])

今はパラメータが平均0、分散1なので変化なし。

以上が層正規化の演算である。ちなみに、バッチ正規化だとこう。

norm = nn.BatchNorm1d(d)
y = norm(x)
y
tensor([[-1.3416, -1.3416, -1.3416, -1.3416],
        [-0.4472, -0.4472, -0.4472, -0.4472],
        [ 0.4472,  0.4472,  0.4472,  0.4472],
        [ 1.3416,  1.3416,  1.3416,  1.3416]],
       grad_fn=<NativeBatchNormBackward0>)

軸が違うのが分かる。

では、これをnn.Moduleとして実装してみよう。

class LayerNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
 
    def forward(self, x):
        mean = x.mean(dim=1, keepdim=True)
        var = x.var(dim=1, unbiased=False, keepdim=True)
        x_hat = (x - mean) / (torch.sqrt(var) + self.eps)
        y = x_hat * self.gamma + self.beta
        return y
norm = LayerNorm(d)
y = norm(x)
y
tensor([[-1.3416, -0.4472,  0.4472,  1.3416],
        [-1.3416, -0.4472,  0.4472,  1.3416],
        [-1.3416, -0.4472,  0.4472,  1.3416],
        [-1.3416, -0.4472,  0.4472,  1.3416]], grad_fn=<AddBackward0>)

同じ結果が得られた。

PyTorchにも実装されている。

norm = nn.LayerNorm(d)
norm.state_dict()
OrderedDict([('weight', tensor([1., 1., 1., 1.])),
             ('bias', tensor([0., 0., 0., 0.]))])

y = norm(x)
y
tensor([[-1.3416, -0.4472,  0.4472,  1.3416],
        [-1.3416, -0.4472,  0.4472,  1.3416],
        [-1.3416, -0.4472,  0.4472,  1.3416],
        [-1.3416, -0.4472,  0.4472,  1.3416]],
       grad_fn=<NativeLayerNormBackward0>)

こちらも同じ結果が得られた。

特徴量は多次元でも可能。また特徴量以外の軸を好きに足してもいい。初めに与えた形状と合致する部分を内側の軸から探してくれる。

例えばRNNではこうなる。

batch_size = 2
seq_len = 3
d = 4
x = torch.arange(batch_size * seq_len * d)
x = x.reshape(batch_size, seq_len, d).to(dtype=torch.float32)
x
tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],
 
        [[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]])

norm = nn.LayerNorm(d)
y = norm(x)
y
tensor([[[-1.3416, -0.4472,  0.4472,  1.3416],
         [-1.3416, -0.4472,  0.4472,  1.3416],
         [-1.3416, -0.4472,  0.4472,  1.3416]],
 
        [[-1.3416, -0.4472,  0.4472,  1.3416],
         [-1.3416, -0.4472,  0.4472,  1.3416],
         [-1.3416, -0.4472,  0.4472,  1.3416]]],
       grad_fn=<NativeLayerNormBackward0>)

各時刻各サンプルが正規化された。

多次元の特徴量も可。

batch_size = 32
c, w, h = 3, 224, 224
x = torch.randn(batch_size, c, w, h)
 
norm = nn.LayerNorm((c, w, h))
y = norm(x)
y.shape
torch.Size([32, 3, 224, 224])

与えた形状と同じ形状のパラメータが用意される。

norm.weight.shape, norm.bias.shape
(torch.Size([3, 224, 224]), torch.Size([3, 224, 224]))

オワリ

バッチ正規化と層正規化の動きをまとめた。正規化層には他にもインスタンス正規化[?]やグループ正規化[?]などがある。PyTorchにも実装がまとまっている: https://pytorch.org/docs/stable/nn.html#normalization-layers

本稿で使用したコードはこちら: deep-learning/normalization_layers.ipynb at main · misya11p/deep-learning

参考