Lasso回帰の数式的な理解とPythonでのスクラッチ実装。ElasticNetも。
線形回帰と正則化
Lasso回帰を学ぶ前に軽く予習をしよう。
線形回帰では入力x ∈ R m \bm x\in\R^m x ∈ R m と重みパラメータw ∈ R m \bm w\in\R^m w ∈ R m の内積をとることでターゲットy y y を予測する。
y ^ = x ⊤ w \begin{align}
\hat y = \bm x^\top\bm w
\end{align} y ^ = x ⊤ w
入力とそれに対応するターゲットがn n n 組得られたとする。これらを
X ∈ R n × m X\in\R^{n\times m} X ∈ R n × m
y ∈ R n \bm y\in\R^n y ∈ R n
と表しておく。
最小二乗法による最適化では、予測値と正解の二乗和誤差が最小になるようなパラメータw \bm w w を求める。これは目的関数を微分して= 0 =0 = 0 とおくことで解析的に求められる。
J LR ( w ) = 1 2 ∥ y − X w ∥ 2 2 ∂ J LR ∂ w = X ⊤ X w − X ⊤ y w ^ = ( X ⊤ X ) − 1 X ⊤ y \begin{align}
J_\text{LR}(\bm w) &= \frac{1}{2} \| \bm y - X\bm w \|_2^2 \\
\frac{\partial J_\text{LR}}{\partial \bm w} &= X^\top X\bm w - X^\top\bm y \\
\hat{\bm w} &= (X^\top X)^{-1}X^\top\bm y
\end{align} J LR ( w ) ∂ w ∂ J LR w ^ = 2 1 ∥ y − X w ∥ 2 2 = X ⊤ X w − X ⊤ y = ( X ⊤ X ) − 1 X ⊤ y
過学習を防ぐため、目的関数に正則化項を足すことがある。パラメータの二乗和で正則化を行う線形回帰はRidge回帰と呼ばれる。二乗和はL2ノルムで表せるのでL2正則化とも呼ばれる。この場合も同様の流れで解析解が得られる。
J Rid ( w ) = 1 2 ∥ y − X w ∥ 2 2 + 1 2 α ∥ w ∥ 2 2 ∂ J Rid ∂ w = X ⊤ X w − X ⊤ y + α w w ^ = ( X ⊤ X + α I ) − 1 X ⊤ y \begin{align}
J_\text{Rid}(\bm w) &= \frac{1}{2} \| \bm y - X\bm w \|_2^2 + \frac{1}{2}\alpha\|\bm w\|_2^2 \\
\frac{\partial J_\text{Rid}}{\partial \bm w} &= X^\top X\bm w - X^\top\bm y + \alpha\bm w \\
\hat{\bm w} &= (X^\top X + \alpha I)^{-1}X^\top\bm y
\end{align} J Rid ( w ) ∂ w ∂ J Rid w ^ = 2 1 ∥ y − X w ∥ 2 2 + 2 1 α ∥ w ∥ 2 2 = X ⊤ X w − X ⊤ y + α w = ( X ⊤ X + α I ) − 1 X ⊤ y
Lasso回帰
Least Absolute Shrinkage and Selection Operator
Lasso回帰では正則化項にパラメータの絶対値の和を用いる。絶対値の和はL1ノルムで表せるのでL1正則化とも呼ばれる。
J Las ( w ) = 1 2 ∥ y − X w ∥ 2 2 + α ∥ w ∥ 1 ∥ w ∥ 1 = ∑ i ∣ w i ∣ \begin{align}
J_\text{Las}(\bm w) = \frac{1}{2} \| \bm y - &X \bm w \|_2^2 + \alpha \| \bm w \|_1 \\
\| \bm w \|_1 &= \sum_i |w_i|
\end{align} J Las ( w ) = 2 1 ∥ y − ∥ w ∥ 1 X w ∥ 2 2 + α ∥ w ∥ 1 = i ∑ ∣ w i ∣
Lasso回帰はRidge回帰と異なり解析解が得られない。絶対値が含まれているため微分が出来ないためである。本稿ではそんなLasso回帰の最適化について、その理論と実装方法をまとめる。
座標降下法
Coordinate Descent
最適化対象の変数を一つずつ順番に最適化する手法。Lasso回帰の最適化においてもこの手法がよく用いられる。
例えばθ = ( θ 1 , θ 2 , … , θ m ) ⊤ \bm\theta = (\theta_1,\theta_2,\ldots,\theta_m)^\top θ = ( θ 1 , θ 2 , … , θ m ) ⊤ の最適化を行うとき、θ \bm\theta θ を一度に最適化するのではなく、個別の変数θ i \theta_i θ i をi = 1 i=1 i = 1 からm m m まで順に最適化していく。θ i \theta_i θ i の最適化はそれ以外の変数θ j ( j ≠ i ) \theta_j\,(j\neq i) θ j ( j = i ) を固定して行う。これを収束するまで繰り返す。
勾配降下法同様、目的関数が凸関数でないと局所解に陥る可能性がある。一方で勾配降下法とは異なり、目的関数が微分可能でなくとも適用できるという汎用性がある。微分は出来ないが、ある一つの変数に着目した場合の最適解は求められる、といった少し特殊なケースで活躍する。
勾配降下法と座標降下法の違いはこんなイメージ。座標降下法では各ステップにおいて一つの変数しか更新されないため、いずれかの軸に対して並行に移動する。
Loading image...
劣微分
Subgradient
微分の定義を少し拡張したようなもの。微分の一般化とも見られる。関数f ( x ) f(x) f ( x ) のx = a x=a x = a における劣微分を次のように定義する。
任意のx x x に対してf ( x ) ≥ f ( a ) + c ( x − a ) f(x)\geq f(a)+c(x-a) f ( x ) ≥ f ( a ) + c ( x − a ) を満たすようなc c c の集合
具体的な図を見るとわかりやすい。例えば、f ( x ) = ∣ x ∣ f(x)=|x| f ( x ) = ∣ x ∣ のx = 0 x=0 x = 0 における劣微分は[ − 1 , 1 ] [-1,1] [ − 1 , 1 ] である。
Loading image...
「x = 0 x=0 x = 0 を通り且つf ( x ) f(x) f ( x ) を超えない直線(動画の赤い直線)の傾きc c c 」が全てx = 0 x=0 x = 0 における劣微分となる。今回の場合は− 1 ≤ c ≤ 1 -1 \leq c \leq 1 − 1 ≤ c ≤ 1 の範囲。ちなみにc c c は劣勾配と呼ぶ。
以上を踏まえると、f ( x ) = ∣ x ∣ f(x)=|x| f ( x ) = ∣ x ∣ の微分を以下に定義できる。
f ′ ( x ) = { − 1 ( x < 0 ) [ − 1 , 1 ] ( x = 0 ) 1 ( x > 0 ) \begin{align}
f'(x) = \begin{cases}
-1 & (x<0) \\
[-1,1] & (x=0) \\
1 & (x>0) \\
\end{cases}
\end{align} f ′ ( x ) = ⎩ ⎨ ⎧ − 1 [ − 1 , 1 ] 1 ( x < 0 ) ( x = 0 ) ( x > 0 )
こうすることで、絶対値を含む関数での微分が考えられるようになる。
絶対値を含む関数の最適化
劣微分を使って、絶対値を含む下に凸な関数を最適化(最小化)する。
通常の下に凸な二次関数は、微分をして0になる点が最小値となる。しかし、絶対値を含む関数は素直に微分ができないので、劣微分を用いる。劣微分に0が含まれる点を求めることで最適解が得られる。例えばf ( x ) = ∣ x ∣ f(x)=|x| f ( x ) = ∣ x ∣ の場合、x = 0 x=0 x = 0 のときに0 ∈ f ′ ( x ) = [ − 1 , 1 ] 0\in f'(x)=[-1,1] 0 ∈ f ′ ( x ) = [ − 1 , 1 ] になるのでx = 0 x=0 x = 0 が最適解となる。
もう少し複雑な関数も考えてみよう。
f ( x ) = a x 2 + b x + c + d ∣ x ∣ a , d > 0 \begin{align}
f(x) = ax^2 + b&x + c + d|x| \\
a, d &> 0
\end{align} f ( x ) = a x 2 + b a , d x + c + d ∣ x ∣ > 0
二次関数に絶対値を足した。これは下に凸の関数である。この関数を最小化するx x x を求めよう。劣微分に0が含まれる点を求める。
まず微分する。
f ′ ( x ) = 2 a x + b + d ( ∣ x ∣ ) ′ \begin{align}
f'(x) = 2ax + b + d(|x|)'
\end{align} f ′ ( x ) = 2 a x + b + d ( ∣ x ∣ ) ′
ここで、
( ∣ x ∣ ) ′ = { − 1 ( x < 0 ) [ − 1 , 1 ] ( x = 0 ) 1 ( x > 0 ) \begin{align}
(|x|)' = \begin{cases}
-1 & (x<0) \\
[-1,1] & (x=0) \\
1 & (x>0) \\
\end{cases}
\end{align} ( ∣ x ∣ ) ′ = ⎩ ⎨ ⎧ − 1 [ − 1 , 1 ] 1 ( x < 0 ) ( x = 0 ) ( x > 0 )
より、
f ′ ( x ) = { 2 a x + b − d ( x < 0 ) [ 2 a x + b − d , 2 a x + b + d ] ( x = 0 ) 2 a x + b + d ( x > 0 ) \begin{align}
f'(x) = \begin{cases}
2ax + b - d & (x<0) \\
[2ax + b - d, 2ax + b + d] & (x=0) \\
2ax + b + d & (x>0) \\
\end{cases}
\end{align} f ′ ( x ) = ⎩ ⎨ ⎧ 2 a x + b − d [ 2 a x + b − d , 2 a x + b + d ] 2 a x + b + d ( x < 0 ) ( x = 0 ) ( x > 0 )
あとはこれが0になる(0を含む)x x x を求める。
まずx < 0 x<0 x < 0 の範囲を考える。f ′ ( x ) = 0 f'(x)=0 f ′ ( x ) = 0 を解く。
f ′ ( x ) = 0 2 a x + b − d = 0 x = − b − d 2 a \begin{align}
f'(x) &= 0 \\
2ax + b - d &= 0 \\
x &= -\frac{b-d}{2a}
\end{align} f ′ ( x ) 2 a x + b − d x = 0 = 0 = − 2 a b − d
そしてこのx x x が< 0 <0 < 0 となるのはb > d b>d b > d の時なので、b > d b>d b > d の時、x = − ( b − d ) / 2 a x=-(b-d)/2a x = − ( b − d ) /2 a がf ( x ) f(x) f ( x ) の最適解となる。
同じようにx > 0 x>0 x > 0 の範囲を考える。
f ′ ( x ) = 0 2 a x + b + d = 0 x = − b + d 2 a \begin{align}
f'(x) &= 0 \\
2ax + b + d &= 0 \\
x &= -\frac{b+d}{2a}
\end{align} f ′ ( x ) 2 a x + b + d x = 0 = 0 = − 2 a b + d
そしてこのx x x が> 0 >0 > 0 となるのはb < − d b<-d b < − d の時なので、b < − d b<-d b < − d の時、x = − ( b + d ) / 2 a x=-(b+d)/2a x = − ( b + d ) /2 a がf ( x ) f(x) f ( x ) の最適解となる。
最後に、どちらにも当てはまらない時、つまり− d ≤ b ≤ d -d\leq b\leq d − d ≤ b ≤ d の時を考えるが、x < 0 x<0 x < 0 でもx > 0 x>0 x > 0 でもないことが確定しているので、x = 0 x=0 x = 0 がf ( x ) f(x) f ( x ) の最適解となる。
まとめるとこうなる。
arg min x f ( x ) = { − b + d 2 a ( b < − d ) 0 ( − d ≤ b ≤ d ) − b − d 2 a ( b > d ) \begin{align}
\argmin_x f(x) = \begin{cases}
-\frac{b+d}{2a} & (b<-d) \\
0 & (-d\leq b\leq d) \\
-\frac{b-d}{2a} & (b>d) \\
\end{cases}
\end{align} x arg min f ( x ) = ⎩ ⎨ ⎧ − 2 a b + d 0 − 2 a b − d ( b < − d ) ( − d ≤ b ≤ d ) ( b > d )
例を見てみよう。a , b , c , d = 1 , 2 , 3 , 4 a, b, c, d = 1, 2, 3, 4 a , b , c , d = 1 , 2 , 3 , 4 とすると、− d ≤ b ≤ d -d\leq b\leq d − d ≤ b ≤ d になるのでx = 0 x=0 x = 0 が最適解となる。
Loading image...
またa , b , c , d = 1 , − 3 , 2 , 1 a, b, c, d = 1, -3, 2, 1 a , b , c , d = 1 , − 3 , 2 , 1 とすると、b < − d b<-d b < − d になるので
x = − b + d 2 a = − − 3 + 1 2 ⋅ 1 = 1 \begin{align}
x = -\frac{b+d}{2a}=-\frac{-3+1}{2\cdot1} = 1
\end{align} x = − 2 a b + d = − 2 ⋅ 1 − 3 + 1 = 1
が最適解となる。
Loading image...
ソフト閾値関数
Soft-Thresholding Function
以下の関数をソフト閾値関数や軟閾値作用素と呼ぶ。
S ( x , λ ) = sign ( x ) max ( ∣ x ∣ − λ , 0 ) = { x + λ ( x < − λ ) 0 ( − λ ≤ x ≤ λ ) x − λ ( x > λ ) λ ≥ 0 \begin{align}
S(x, \lambda)
&= \text{sign}(x) \max(|x|-\lambda, 0) \\
&= \begin{cases}
x+\lambda & (x<-\lambda) \\
0 & (-\lambda\leq x\leq \lambda) \\
x-\lambda & (x>\lambda) \\
\end{cases} \\
\lambda &\geq 0
\end{align} S ( x , λ ) λ = sign ( x ) max ( ∣ x ∣ − λ , 0 ) = ⎩ ⎨ ⎧ x + λ 0 x − λ ( x < − λ ) ( − λ ≤ x ≤ λ ) ( x > λ ) ≥ 0
Pythonで書くとこう。
def soft_thresholding ( x , threshold ):
return np.sign(x) * np.maximum(np.abs(x) - threshold, 0 )
λ = 1 \lambda=1 λ = 1 とすると、こんなグラフ。
Loading image...
λ \lambda λ を大きくすると0になる範囲が広がる。
この関数を用いると、先の解を以下のように表せる。
arg min x f ( x ) = { − b + d 2 a ( b < − d ) 0 ( − d ≤ b ≤ d ) − b − d 2 a ( b > d ) = S ( − b 2 a , d 2 a ) \begin{align}
\argmin_x f(x)
&= \begin{cases}
-\frac{b+d}{2a} & (b<-d) \\
0 & (-d\leq b\leq d) \\
-\frac{b-d}{2a} & (b>d) \\
\end{cases} \\
&= S\left(-\frac{b}{2a}, \frac{d}{2a}\right)
\end{align} x arg min f ( x ) = ⎩ ⎨ ⎧ − 2 a b + d 0 − 2 a b − d ( b < − d ) ( − d ≤ b ≤ d ) ( b > d ) = S ( − 2 a b , 2 a d )
Lasso回帰の最適化
ではいよいよLasso回帰の最適化を行おう。目的関数は以下とする。
J Las ( w ) = 1 2 N ∥ y − X w ∥ 2 2 + α ∥ w ∥ 1 \begin{align}
J_\text{Las}(\bm w) = \frac{1}{2N}\| \bm y - X \bm w \|_2^2 + \alpha \| \bm w \|_1
\end{align} J Las ( w ) = 2 N 1 ∥ y − X w ∥ 2 2 + α ∥ w ∥ 1
scikit-learn に倣い、一項目をデータ数N N N で割った。どうせα \alpha α で比重を調整できるので定数を掛けても問題ない。N N N で割っておくとα \alpha α に関する議論でN N N をあまり意識しなくてよくなる(データ数に応じてα \alpha α を調整する必要がなくなる)。
ちなみにさっきからずっと誤差項に1 / 2 1/2 1/2 を掛けているのは微分したときにきれいになるからってだけ。あと一応後の章で出てくる正規分布との兼ね合いもある。分散1の正規分布を仮定すると1 / 2 1/2 1/2 だけ残るのでそれと一致して分かりやすいよねって感じ。
座標降下法を用いて最適化を行う。ある一つのパラメータw i w_i w i に着目し、それ以外のパラメータw j ( i ≠ j ) w_j\,(i\neq j) w j ( i = j ) を固定した状態でw i w_i w i を最小化する。
目的関数をw i w_i w i について整理しよう。
J Las ( w i ) = 1 2 N ∥ y − X w ∥ 2 2 + α ∥ w ∥ 1 = 1 2 N ∑ n = 1 N ( y ( n ) − w ⊤ x ( n ) ) 2 + α ∥ w ∥ 1 = 1 2 N ∑ n = 1 N ( y ( n ) − w \ i ⊤ x \ i ( n ) − w i x i ( n ) ) 2 + α ( ∥ w \ i ∥ 1 + ∣ w i ∣ ) \begin{align}
J_\text{Las}(w_i)
&= \frac{1}{2N} \| \bm y - X \bm w \|_2^2 + \alpha \| \bm w \|_1 \\
&= \frac{1}{2N} \sum_{n=1}^N \left( y^{(n)} - \bm w^\top\bm x^{(n)} \right)^2 + \alpha \| \bm w \|_1 \\
&= \frac{1}{2N} \sum_{n=1}^N \left( y^{(n)} - \bm w_{\backslash i}^\top\bm x_{\backslash i}^{(n)} - w_ix_i^{(n)} \right)^2 + \alpha(\| \bm w_{\backslash i} \|_1 + |w_i|) \\
\end{align} J Las ( w i ) = 2 N 1 ∥ y − X w ∥ 2 2 + α ∥ w ∥ 1 = 2 N 1 n = 1 ∑ N ( y ( n ) − w ⊤ x ( n ) ) 2 + α ∥ w ∥ 1 = 2 N 1 n = 1 ∑ N ( y ( n ) − w \ i ⊤ x \ i ( n ) − w i x i ( n ) ) 2 + α ( ∥ w \ i ∥ 1 + ∣ w i ∣ )
x \ i \bm x_{\backslash i} x \ i はx \bm x x のi i i 番目の要素x i x_i x i 以外を並べたベクトル。今後のためにi i i とそれ以外で分けた。
ここで以下を定義する。
r \ i ( n ) = y ( n ) − w \ i ⊤ x \ i ( n ) r \ i = ( r \ i ( 1 ) , r \ i ( 2 ) , ⋯ , r \ i ( N ) ) ⊤ ∈ R N x i = ( x i ( 1 ) , x i ( 2 ) , ⋯ , x i ( N ) ) ⊤ ∈ R N \begin{align}
r_{\backslash i}^{(n)} &= y^{(n)} - \bm w_{\backslash i}^\top\bm x_{\backslash i}^{(n)} \\
\bm r_{\backslash i} &= (r_{\backslash i}^{(1)}, r_{\backslash i}^{(2)}, \cdots, r_{\backslash i}^{(N)})^\top \in \R^N \\
\bm x_i &= (x_i^{(1)}, x_i^{(2)}, \cdots, x_i^{(N)})^\top \in \R^N
\end{align} r \ i ( n ) r \ i x i = y ( n ) − w \ i ⊤ x \ i ( n ) = ( r \ i ( 1 ) , r \ i ( 2 ) , ⋯ , r \ i ( N ) ) ⊤ ∈ R N = ( x i ( 1 ) , x i ( 2 ) , ⋯ , x i ( N ) ) ⊤ ∈ R N
r \ i ( n ) r_{\backslash i}^{(n)} r \ i ( n ) はx ( n ) \bm x^{(n)} x ( n ) のi i i 番目の特徴量x i ( n ) x_i^{(n)} x i ( n ) を考慮せずに予測した値の誤差。r \ i \bm r_{\backslash i} r \ i はそれらを並べたベクトル。x i \bm x_i x i はi i i 番目の特徴量を並べたベクトル。これで変形を進める。
= 1 2 N ∑ n = 1 N ( r \ i ( n ) − w i x i ( n ) ) 2 + α ( ∥ w \ i ∥ 1 + ∣ w i ∣ ) = 1 2 N ∑ n = 1 N ( ( r \ i ( n ) ) 2 − 2 r \ i ( n ) w i x i ( n ) + ( w i x i ( n ) ) 2 ) + α ( ∥ w \ i ∥ 1 + ∣ w i ∣ ) = 1 2 N ∑ n = 1 N ( − 2 r \ i ( n ) w i x i ( n ) + ( w i x i ( n ) ) 2 ) + α ∣ w i ∣ + const \begin{align}
&= \frac{1}{2N} \sum_{n=1}^N \left( r_{\backslash i}^{(n)} - w_ix_i^{(n)} \right)^2 + \alpha(\| \bm w_{\backslash i} \|_1 + |w_i|) \\
&= \frac{1}{2N}\sum_{n=1}^N \left( (r_{\backslash i}^{(n)})^2 - 2r_{\backslash i}^{(n)}w_ix_i^{(n)} + (w_ix_i^{(n)})^2 \right) + \alpha(\| \bm w_{\backslash i} \|_1 + |w_i|) \\
&= \frac{1}{2N}\sum_{n=1}^N \left( -2r_{\backslash i}^{(n)}w_ix_i^{(n)} + (w_ix_i^{(n)})^2 \right) + \alpha |w_i| + \text{const} \\
\end{align} = 2 N 1 n = 1 ∑ N ( r \ i ( n ) − w i x i ( n ) ) 2 + α ( ∥ w \ i ∥ 1 + ∣ w i ∣ ) = 2 N 1 n = 1 ∑ N ( ( r \ i ( n ) ) 2 − 2 r \ i ( n ) w i x i ( n ) + ( w i x i ( n ) ) 2 ) + α ( ∥ w \ i ∥ 1 + ∣ w i ∣ ) = 2 N 1 n = 1 ∑ N ( − 2 r \ i ( n ) w i x i ( n ) + ( w i x i ( n ) ) 2 ) + α ∣ w i ∣ + const
w i w_i w i に関係のない項を定数としてconst \text{const} const にまとめた。さらに変形を進めて
= 1 2 N ( − 2 w i ∑ n = 1 N r \ i ( n ) x i ( n ) + w i 2 ∑ n = 1 N ( x i ( n ) ) 2 ) + α ∣ w i ∣ + const = 1 2 N ( − 2 w i r \ i ⊤ x i + w i 2 ∥ x i ∥ 2 2 ) + α ∣ w i ∣ + const = − 1 N w i r \ i ⊤ x i + 1 2 N w i 2 ∥ x i ∥ 2 2 + α ∣ w i ∣ + const \begin{align}
&= \frac{1}{2N} \left( -2w_i \sum_{n=1}^N r_{\backslash i}^{(n)}x_i^{(n)} + w_i^2 \sum_{n=1}^N (x_i^{(n)})^2 \right) + \alpha |w_i| + \text{const} \\
&= \frac{1}{2N} \left( -2w_i\bm r_{\backslash i}^\top\bm x_i + w_i^2\|\bm x_i\|_2^2 \right) + \alpha |w_i| + \text{const} \\
&= -\frac{1}{N}w_i\bm r_{\backslash i}^\top\bm x_i + \frac{1}{2N}w_i^2\|\bm x_i\|_2^2 + \alpha |w_i| + \text{const} \\
\end{align} = 2 N 1 ( − 2 w i n = 1 ∑ N r \ i ( n ) x i ( n ) + w i 2 n = 1 ∑ N ( x i ( n ) ) 2 ) + α ∣ w i ∣ + const = 2 N 1 ( − 2 w i r \ i ⊤ x i + w i 2 ∥ x i ∥ 2 2 ) + α ∣ w i ∣ + const = − N 1 w i r \ i ⊤ x i + 2 N 1 w i 2 ∥ x i ∥ 2 2 + α ∣ w i ∣ + const
w i 2 w_i^2 w i 2 について整理すると
J Las ( w i ) = 1 2 N ∥ x i ∥ 2 2 w i 2 − 1 N r \ i ⊤ x i w i + const + α ∣ w i ∣ \begin{align}
J_\text{Las}(w_i) = \frac{1}{2N} \| \bm x_i \|_2^2w_i^2 - \frac{1}{N}\bm r_{\backslash i}^\top\bm x_iw_i + \text{const} + \alpha|w_i|
\end{align} J Las ( w i ) = 2 N 1 ∥ x i ∥ 2 2 w i 2 − N 1 r \ i ⊤ x i w i + const + α ∣ w i ∣
前章で解いたf ( x ) = a x 2 + b x + c + d ∣ x ∣ ( a , d > 0 ) f(x)=ax^2+bx+c+d|x|\quad(a, d > 0) f ( x ) = a x 2 + b x + c + d ∣ x ∣ ( a , d > 0 ) と同じ形になった。この解はS ( − b 2 a , d 2 a ) S\left(-\frac{b}{2a}, \frac{d}{2a}\right) S ( − 2 a b , 2 a d ) であるため、
a = 1 2 N ∥ x i ∥ 2 2 b = − 1 N r \ i ⊤ x i d = α \begin{align}
a &= \frac{1}{2N} \| \bm x_i \|_2^2 \\
b &= -\frac{1}{N}\bm r_{\backslash i}^\top\bm x_i \\
d &= \alpha \\
\end{align} a b d = 2 N 1 ∥ x i ∥ 2 2 = − N 1 r \ i ⊤ x i = α
を代入すると
arg min w i J Las ( w i ) = S ( r \ i ⊤ x i ∥ x i ∥ 2 2 , N α ∥ x i ∥ 2 2 ) \begin{align}
\argmin_{w_i} J_\text{Las}(w_i) = S \left( \frac{\bm r_{\backslash i}^\top\bm x_i}{\|\bm x_i\|_2^2}, \frac{N\alpha}{\|\bm x_i\|_2^2} \right)
\end{align} w i arg min J Las ( w i ) = S ( ∥ x i ∥ 2 2 r \ i ⊤ x i , ∥ x i ∥ 2 2 N α )
が得られる。これが各ステップにおけるw i w_i w i の更新先となる。
また、バイアス項に対応するパラメータは正則化をしないので、解が変わる。w 0 w_0 w 0 をバイアスとし、正則化項を無視すると、
J Las ( w 0 ) = 1 2 N ∥ x 0 ∥ 2 2 w 0 2 − 1 N r \ 0 ⊤ x 0 w 0 + const \begin{align}
J_\text{Las}(w_0) = \frac{1}{2N} \|\bm x_0 \|_2^2w_0^2 - \frac{1}{N}\bm r_{\backslash 0}^\top\bm x_0w_0 + \text{const}
\end{align} J Las ( w 0 ) = 2 N 1 ∥ x 0 ∥ 2 2 w 0 2 − N 1 r \0 ⊤ x 0 w 0 + const
となるので、これを微分し、
∂ J Las ∂ w 0 = 1 N ∥ x 0 ∥ 2 2 w 0 − 1 N r \ 0 ⊤ x 0 \begin{align}
\frac{\partial J_\text{Las}}{\partial w_0} = \frac{1}{N}\|\bm x_0\|_2^2w_0 - \frac{1}{N}\bm r_{\backslash 0}^\top\bm x_0
\end{align} ∂ w 0 ∂ J Las = N 1 ∥ x 0 ∥ 2 2 w 0 − N 1 r \0 ⊤ x 0
= 0 =0 = 0 とおいてw 0 w_0 w 0 について解く。
w 0 = r \ 0 ⊤ x 0 ∥ x 0 ∥ 2 2 \begin{align}
w_0 = \frac{\bm r_{\backslash 0}^\top\bm x_0}{\|\bm x_0\|_2^2}
\end{align} w 0 = ∥ x 0 ∥ 2 2 r \0 ⊤ x 0
w 0 w_0 w 0 をバイアスとする場合、x 0 = ( 1 , 1 , ⋯ , 1 ) ⊤ ∈ R N \bm x_0=(1, 1, \cdots, 1)^\top\in\R^N x 0 = ( 1 , 1 , ⋯ , 1 ) ⊤ ∈ R N となるので、
w 0 = ∑ n = 1 N r \ 0 ( n ) N = r \ 0 ˉ \begin{align}
w_0
&= \frac{\sum_{n=1}^Nr_{\backslash 0}^{(n)}}{N} \\
&= \bar{r_{\backslash 0}} \\
\end{align} w 0 = N ∑ n = 1 N r \0 ( n ) = r \0 ˉ
となる。バイアス項を無視した予測値と正解の差の平均だね。
以上で全てのパラメータの更新先が得られた。最後にLasso回帰の最適化手順をまとめる。
w \bm w w を初期化
w 0 w_0 w 0 を最適化: w 0 ← r \ 0 ˉ w_0 \leftarrow \bar{r_{\backslash 0}} w 0 ← r \0 ˉ
w 1 , w 2 , ⋯ , w m w_1, w_2, \cdots, w_m w 1 , w 2 , ⋯ , w m を最適化: w i ← S ( r \ i ⊤ x i ∥ x i ∥ 2 2 , N α ∥ x i ∥ 2 2 ) w_i \leftarrow S \left( \frac{\bm r_{\backslash i}^\top\bm x_i}{\|\bm x_i\|_2^2}, \frac{N\alpha}{\|\bm x_i\|_2^2} \right) w i ← S ( ∥ x i ∥ 2 2 r \ i ⊤ x i , ∥ x i ∥ 2 2 N α )
収束するまで2, 3を繰り返す
Python実装
実装してみよう。
import numpy as np
class Lasso :
def __init__ ( self , alpha : float = 1 ., max_iter : int = 1000 ):
self .alpha = alpha
self .max_iter = max_iter # 繰り返しの上限
self .weights = None
@ staticmethod
def soft_thresholding ( x , threshold ):
return np.sign(x) * np.maximum(np.abs(x) - threshold, 0 )
def fit ( self , X , y ):
X = np.insert(X, 0 , 1 , axis = 1 )
n_samples, n_features = X.shape
self .weights = np.zeros(n_features) # パラメータの初期化
for _ in range ( self .max_iter): # 上限まで繰り返し
# w0の更新
self .weights[ 0 ] = np.mean(y - X[:, 1 :] @ self .weights[ 1 :])
# w1, w2, ..., wmの更新
for i in range ( 1 , n_features):
r_i = y - np.delete(X, i, 1 ) @ np.delete( self .weights, i)
x_i = X[:, i]
norm = np.linalg.norm(x_i) ** 2
self .weights[i] = self .soft_thresholding(
r_i @ x_i / norm, n_samples * self .alpha / norm
)
def predict ( self , X ):
X = np.insert(X, 0 , 1 , axis = 1 )
return X @ self .weights
収束するまでと書いたが、コードをシンプルにするため、繰り返しの上限max_iter
を導入し、上限に達したら終了するようにした。
適当なデータで学習させてみる。
from sklearn.datasets import load_diabetes
X, y = load_diabetes( return_X_y = True )
model = Lasso( alpha = 1 .)
model.fit(X, y)
print ( " \n " .join(model.weights.astype( str ).tolist()))
152.133484162896
0.0
-0.0
367.70162582143126
6.30970264417499
0.0
0.0
-0.0
0.0
307.60214746219583
0.0
sklearnとも一致する。
from sklearn.linear_model import Lasso as LassoSklearn
model = LassoSklearn( alpha = 1 ., tol = 0 ., max_iter = 1000 )
model.fit(X, y)
weights = np.append([model.intercept_], model.coef_)
print ( " \n " .join(weights.astype( str ).tolist()))
152.133484162896
0.0
-0.0
367.70162582143115
6.309702644174996
0.0
0.0
-0.0
0.0
307.6021474621959
0.0
微小な誤差が生じているが、細かい実装方法が違うだけかな。
Lasso回帰の性質
最適化から離れた、ちょっとしたLasso回帰の小話
スパース推定
前章で得られた解を見てみると、多くの要素が0になっていることが分かる。このような多くの要素が0であるベクトルはスパースなベクトルや疎ベクトルと呼ばれる。反対に多くの要素が0でないベクトルは密ベクトルと呼ばれる。
こういったスパースな解を得ることはスパース推定やスパースモデリングと呼ばれる。スパース推定によって多くのデータの中から重要なデータのみを取り出すことができる。Lasso回帰では、複数の特徴量の中から目的変数の予測に寄与するものだけを取り出すことができる。Lasso回帰は特徴量選択という面で非常に解釈性の高いモデルと言える。
なぜLasso回帰によってスパースな解が得られるかはパラメータの更新式を見ると分かりやすい。
w i ← S ( r \ i ⊤ x i ∥ x i ∥ 2 2 , N α ∥ x i ∥ 2 2 ) \begin{align}
w_i \leftarrow S \left( \frac{\bm r_{\backslash i}^\top\bm x_i}{\|\bm x_i\|_2^2}, \frac{N\alpha}{\|\bm x_i\|_2^2} \right)
\end{align} w i ← S ( ∥ x i ∥ 2 2 r \ i ⊤ x i , ∥ x i ∥ 2 2 N α )
ソフト閾値関数S ( x , λ ) S(x, \lambda) S ( x , λ ) はx x x の絶対値が閾値λ \lambda λ を超えない場合に0となる関数である。α \alpha α を大きくすると閾値が大きくなり、結果、w i w_i w i が0になることが多くなる。
ちなみに、L1ではなくL0ノルムによる正則化を行う手法もある。L0ノルムはベクトルの非ゼロの要素の数を表すため、より厳密なスパース推定が可能になる。ただ計算量的に最適化が困難であるため、多くの場合はL1ノルムによる正則化が用いられる。
ベイズ的解釈
線形回帰の最小二乗法は誤差に正規分布を仮定した最尤推定と解釈できる。実際にその仮定で対数尤度の計算を進めると
p ( y ∣ x ; w ) = N ( y ; w ⊤ x , 1 ) ln p ( y ∣ X ; w ) = ∑ n ln p ( y ( n ) ∣ x ( n ) ; w ) = ∑ n ln ( 1 2 π exp ( − ( y ( n ) − w ⊤ x ( n ) ) 2 2 ) ) = − 1 2 ∑ n ( y ( n ) − w ⊤ x ( n ) ) 2 + const = − 1 2 ∥ y − X w ∥ 2 2 + const \begin{align}
p(y|\bm x; \bm w) &= \mathcal N(y; \bm w^\top\bm x,1) \\
\ln p(\bm y| X; \bm w)
&= \sum_n \ln p(y^{(n)}| \bm x^{(n)}; \bm w) \\
&= \sum_n \ln \left( \frac{1}{\sqrt{2\pi}} \exp \left( -\frac{(y^{(n)} - \bm w^\top \bm x^{(n)})^2}{2} \right) \right) \\
&= -\frac{1}{2} \sum_n (y^{(n)} - \bm w^\top \bm x^{(n)})^2 + \text{const} \\
&= -\frac{1}{2} \| \bm y - X\bm w \|_2^2 + \text{const}
\end{align} p ( y ∣ x ; w ) ln p ( y ∣ X ; w ) = N ( y ; w ⊤ x , 1 ) = n ∑ ln p ( y ( n ) ∣ x ( n ) ; w ) = n ∑ ln ( 2 π 1 exp ( − 2 ( y ( n ) − w ⊤ x ( n ) ) 2 ) ) = − 2 1 n ∑ ( y ( n ) − w ⊤ x ( n ) ) 2 + const = − 2 1 ∥ y − X w ∥ 2 2 + const
負の二乗和誤差が出てくる。この最大化は二乗和誤差の最小化と同じである。ではここに正則化項を足すことは何を意味するか。
ここに正則化項を足すことは、確率モデルに事前分布p ( w ) p(\bm w) p ( w ) を仮定することと同じ意味になる。例えばL2正則化はパラメータの事前分布として平均0の正規分布を仮定することと解釈できる。
p ( w ) = N ( w ; 0 , 1 λ I ) ln p ( w ∣ X , y ) = ln p ( y ∣ X ; w ) + ln p ( w ) = − 1 2 ∥ y − X w ∥ 2 2 − 1 2 λ ∥ w ∥ 2 2 + const \begin{align}
p(\bm w) &= \mathcal N(\bm w; \bm 0, \frac{1}{\lambda} I) \\
\ln p(\bm w| X, \bm y)
&= \ln p(\bm y| X; \bm w) + \ln p(\bm w) \\
&= -\frac{1}{2} \| \bm y - X\bm w \|_2^2 - \frac{1}{2}\lambda \| \bm w\|_2^2 + \text{const}
\end{align} p ( w ) ln p ( w ∣ X , y ) = N ( w ; 0 , λ 1 I ) = ln p ( y ∣ X ; w ) + ln p ( w ) = − 2 1 ∥ y − X w ∥ 2 2 − 2 1 λ ∥ w ∥ 2 2 + const
Lasso回帰はL1正則化を行うが、これはパラメータの事前分布としてラプラス分布を仮定していると解釈できる。
p ( w ) = Lap ( w ; 0 , λ ) ln p ( w ∣ X , y ) = ln p ( y ∣ X ; w ) + ln p ( w ) = − 1 2 ∥ y − X w ∥ 2 2 − λ ∥ w ∥ 1 + const \begin{align}
p(\bm w) &= \text{Lap}(\bm w; \bm 0, \lambda) \\
\ln p(\bm w| X, \bm y)
&= \ln p(\bm y| X; \bm w) + \ln p(\bm w) \\
&= -\frac{1}{2} \| \bm y - X\bm w \|_2^2 - \lambda \| \bm w\|_1 + \text{const} \\
\end{align} p ( w ) ln p ( w ∣ X , y ) = Lap ( w ; 0 , λ ) = ln p ( y ∣ X ; w ) + ln p ( w ) = − 2 1 ∥ y − X w ∥ 2 2 − λ ∥ w ∥ 1 + const
このあたりの細かい話は別の記事に書いたので、興味があればぜひ: 線形回帰における最尤推定・MAP推定・ベイズ推定
ElasticNet
Ridge回帰とLasso回帰を組み合わせたモデル。L1正則化とL2正則化を両方行う。次の目的関数を最適化する。
J EN ( w ) = 1 2 N ∥ y − X w ∥ 2 2 + α ( β ∥ w ∥ 1 + ( 1 − β ) 1 2 ∥ w ∥ 2 2 ) \begin{align}
J_\text{EN}(\bm w) = \frac{1}{2N} \| \bm y - \bm X \bm w \|_2^2 + \alpha \left( \beta \| \bm w \|_1 + (1-\beta) \frac{1}{2} \| \bm w \|_2^2 \right)
\end{align} J EN ( w ) = 2 N 1 ∥ y − X w ∥ 2 2 + α ( β ∥ w ∥ 1 + ( 1 − β ) 2 1 ∥ w ∥ 2 2 )
β \beta β はL1正則化の割合を表すハイパーパラメータ。β = 1 \beta=1 β = 1 とするとLasso回帰、β = 0 \beta=0 β = 0 とするとRidge回帰と一致する。
解き方はLassoと一緒。絶対値が含まれているので解析的には解けない。ここでも座標降下法を使う。Lasso同様、ある一つのパラメータw i w_i w i に着目した目的関数J ( w i ) J(w_i) J ( w i ) が必要。上の式を変形していっても良いが、面倒くさいので、Lasso回帰とどこが変わったかを考える。
Lasso回帰の目的関数は以下であった。
J Las ( w i ) = 1 2 N ∥ x i ∥ 2 2 w i 2 − 1 N r \ i ⊤ x i w i + const + α ∣ w i ∣ \begin{align}
J_\text{Las}(w_i) = \frac{1}{2N} \| \bm x_i \|_2^2w_i^2 - \frac{1}{N} \bm r_{\backslash i}^\top\bm x_iw_i + \text{const} + \alpha|w_i|
\end{align} J Las ( w i ) = 2 N 1 ∥ x i ∥ 2 2 w i 2 − N 1 r \ i ⊤ x i w i + const + α ∣ w i ∣
各項の係数がどう変化するかを考えよう。L2正則化は各パラメータw i w_i w i の二乗を足すので、増えた分を二乗の項の係数に足せば良い。つまりハイパーパラメータと定数を掛けたα ( 1 − β ) / 2 \alpha(1-\beta)/2 α ( 1 − β ) /2 を足せば良い。またL1ノルムには新たなパラメータβ \beta β が追加で掛けられる。これらをまとめるとこうなる。
J EN ( w i ) = 1 2 N ( ∥ x i ∥ 2 2 + N α ( 1 − β ) ) w i 2 − 1 N r \ i ⊤ x i w i + const + α β ∣ w i ∣ \begin{align}
J_\text{EN}(w_i) = \frac{1}{2N}(\|\bm x_i\|_2^2 + N\alpha(1-\beta))w_i^2 - \frac{1}{N}\bm r_{\backslash i}^\top\bm x_iw_i + \text{const} + \alpha\beta|w_i|
\end{align} J EN ( w i ) = 2 N 1 ( ∥ x i ∥ 2 2 + N α ( 1 − β )) w i 2 − N 1 r \ i ⊤ x i w i + const + α β ∣ w i ∣
ここからの解き方は同じ。
a = 1 2 N ( ∥ x i ∥ 2 2 + N α ( 1 − β ) ) b = − 1 N r \ i ⊤ x i d = α β \begin{align}
a &= \frac{1}{2N}\big(\|\bm x_i\|_2^2 + N\alpha(1-\beta)\big) \\
b &= -\frac{1}{N}\bm r_{\backslash i}^\top\bm x_i \\
d &= \alpha\beta
\end{align} a b d = 2 N 1 ( ∥ x i ∥ 2 2 + N α ( 1 − β ) ) = − N 1 r \ i ⊤ x i = α β
とすると、解はS ( − b 2 a , d 2 a ) S\left(-\frac{b}{2a}, \frac{d}{2a}\right) S ( − 2 a b , 2 a d ) なので
arg min w i J EN ( w i ) = S ( r \ i ⊤ x i ∥ x i ∥ 2 2 + N α ( 1 − β ) , N α β ∥ x i ∥ 2 2 + N α ( 1 − β ) ) \begin{align}
\argmin_{w_i} J_\text{EN}(w_i) = S \left( \frac{\bm r_{\backslash i}^\top\bm x_i}{\|\bm x_i\|_2^2 + N\alpha(1-\beta)}, \frac{N\alpha\beta}{\|\bm x_i\|_2^2 + N\alpha(1-\beta)} \right)
\end{align} w i arg min J EN ( w i ) = S ( ∥ x i ∥ 2 2 + N α ( 1 − β ) r \ i ⊤ x i , ∥ x i ∥ 2 2 + N α ( 1 − β ) N α β )
が得られる。
またバイアスについては正則化項を考慮しないため先と同じく
w 0 = r \ 0 ˉ \begin{align}
w_0 = \bar{r_{\backslash 0}}
\end{align} w 0 = r \0 ˉ
となる。
実装してみよう。
class ElasticNet :
def __init__ (
self ,
alpha : float = 1 .,
beta : float = 0.5 ,
max_iter : int = 1000
):
self .alpha = alpha
self .beta = beta
self .max_iter = max_iter
self .weights = None
@ staticmethod
def soft_thresholding ( x , threshold ):
return np.sign(x) * np.maximum(np.abs(x) - threshold, 0 )
def fit ( self , X , y ):
X = np.insert(X, 0 , 1 , axis = 1 )
n_samples, n_features = X.shape
self .weights = np.zeros(n_features)
for _ in range ( self .max_iter):
self .weights[ 0 ] = np.mean(y - X[:, 1 :] @ self .weights[ 1 :])
for i in range ( 1 , n_features):
r_i = y - np.delete(X, i, 1 ) @ np.delete( self .weights, i)
x_i = X[:, i]
norm = np.linalg.norm(x_i) ** 2
norm += n_samples * self .alpha * ( 1 - self .beta) # 追加
self .weights[i] = self .soft_thresholding(
r_i @ x_i / norm,
n_samples * self .alpha * self .beta / norm # 変更
)
def predict ( self , X ):
X = np.insert(X, 0 , 1 , axis = 1 )
return X @ self .weights
更新式が少しだけ変わった。分母はノルムだけではないからnorm
という変数名を変えるか迷ったけど、normには規格という意味もあるみたいなのでそのままにした。
model = ElasticNet( alpha = 1 ., beta = 0.5 )
model.fit(X, y)
print ( " \n " .join(model.weights.astype( str ).tolist()))
152.13348416289594
0.3590175634148627
0.0
3.259766998005527
2.2043402383839803
0.5286453997828984
0.2509350904357106
-1.8613631921210814
2.1144540777001035
3.105834685472744
1.7698510183435376
sklearnとも一致する。
from sklearn.linear_model import ElasticNet as ElasticNetSklearn
model = ElasticNetSklearn( alpha = 1 ., l1_ratio = 0.5 , tol = 0 ., max_iter = 1000 )
model.fit(X, y)
weights = np.append([model.intercept_], model.coef_)
print ( " \n " .join(weights.astype( str ).tolist()))
152.13348416289594
0.3590175634148638
0.0
3.2597669980055266
2.204340238383981
0.5286453997828972
0.2509350904357103
-1.8613631921210825
2.1144540777001053
3.105834685472747
1.7698510183435392
オワリ
おつ。