icon

Lasso回帰の理論と実装

公開: 2025-05-25 / 最終更新: 2025-06-29
機械学習線形回帰LassoElasticNet

Lasso回帰の数式的な理解とPythonでのスクラッチ実装。ElasticNetも。

線形回帰と正則化

Lasso回帰を学ぶ前に軽く予習をしよう。

線形回帰では入力xRm\bm x\in\R^mと重みパラメータwRm\bm w\in\R^mの内積をとることでターゲットyyを予測する。

y^=xw\begin{align} \hat y = \bm x^\top\bm w \end{align}

入力とそれに対応するターゲットがnn組得られたとする。これらを

  • XRn×mX\in\R^{n\times m}
  • yRn\bm y\in\R^n

と表しておく。

最小二乗法による最適化では、予測値と正解の二乗和誤差が最小になるようなパラメータw\bm wを求める。これは目的関数を微分して=0=0とおくことで解析的に求められる。

JLR(w)=12yXw22JLRw=XXwXyw^=(XX)1Xy\begin{align} J_\text{LR}(\bm w) &= \frac{1}{2} \| \bm y - X\bm w \|_2^2 \\ \frac{\partial J_\text{LR}}{\partial \bm w} &= X^\top X\bm w - X^\top\bm y \\ \hat{\bm w} &= (X^\top X)^{-1}X^\top\bm y \end{align}

過学習を防ぐため、目的関数に正則化項を足すことがある。パラメータの二乗和で正則化を行う線形回帰はRidge回帰と呼ばれる。二乗和はL2ノルムで表せるのでL2正則化とも呼ばれる。この場合も同様の流れで解析解が得られる。

JRid(w)=12yXw22+12αw22JRidw=XXwXy+αww^=(XX+αI)1Xy\begin{align} J_\text{Rid}(\bm w) &= \frac{1}{2} \| \bm y - X\bm w \|_2^2 + \frac{1}{2}\alpha\|\bm w\|_2^2 \\ \frac{\partial J_\text{Rid}}{\partial \bm w} &= X^\top X\bm w - X^\top\bm y + \alpha\bm w \\ \hat{\bm w} &= (X^\top X + \alpha I)^{-1}X^\top\bm y \end{align}

Lasso回帰

Least Absolute Shrinkage and Selection Operator

Lasso回帰では正則化項にパラメータの絶対値の和を用いる。絶対値の和はL1ノルムで表せるのでL1正則化とも呼ばれる。

JLas(w)=12yXw22+αw1w1=iwi\begin{align} J_\text{Las}(\bm w) = \frac{1}{2} \| \bm y - &X \bm w \|_2^2 + \alpha \| \bm w \|_1 \\ \| \bm w \|_1 &= \sum_i |w_i| \end{align}

Lasso回帰はRidge回帰と異なり解析解が得られない。絶対値が含まれているため微分が出来ないためである。本稿ではそんなLasso回帰の最適化について、その理論と実装方法をまとめる。

座標降下法

Coordinate Descent

最適化対象の変数を一つずつ順番に最適化する手法。Lasso回帰の最適化においてもこの手法がよく用いられる。

例えばθ=(θ1,θ2,,θm)\bm\theta = (\theta_1,\theta_2,\ldots,\theta_m)^\topの最適化を行うとき、θ\bm\thetaを一度に最適化するのではなく、個別の変数θi\theta_ii=1i=1からmmまで順に最適化していく。θi\theta_iの最適化はそれ以外の変数θj(ji)\theta_j\,(j\neq i)を固定して行う。これを収束するまで繰り返す。

勾配降下法同様、目的関数が凸関数でないと局所解に陥る可能性がある。一方で勾配降下法とは異なり、目的関数が微分可能でなくとも適用できるという汎用性がある。微分は出来ないが、ある一つの変数に着目した場合の最適解は求められる、といった少し特殊なケースで活躍する。

勾配降下法と座標降下法の違いはこんなイメージ。座標降下法では各ステップにおいて一つの変数しか更新されないため、いずれかの軸に対して並行に移動する。

Loading image...

劣微分

Subgradient

微分の定義を少し拡張したようなもの。微分の一般化とも見られる。関数f(x)f(x)x=ax=aにおける劣微分を次のように定義する。

任意のxxに対してf(x)f(a)+c(xa)f(x)\geq f(a)+c(x-a)を満たすようなccの集合

具体的な図を見るとわかりやすい。例えば、f(x)=xf(x)=|x|x=0x=0における劣微分は[1,1][-1,1]である。

Loading image...

x=0x=0を通り且つf(x)f(x)を超えない直線(動画の赤い直線)の傾きcc」が全てx=0x=0における劣微分となる。今回の場合は1c1-1 \leq c \leq 1の範囲。ちなみにccは劣勾配と呼ぶ。

以上を踏まえると、f(x)=xf(x)=|x|の微分を以下に定義できる。

f(x)={1(x<0)[1,1](x=0)1(x>0)\begin{align} f'(x) = \begin{cases} -1 & (x<0) \\ [-1,1] & (x=0) \\ 1 & (x>0) \\ \end{cases} \end{align}

こうすることで、絶対値を含む関数での微分が考えられるようになる。

絶対値を含む関数の最適化

劣微分を使って、絶対値を含む下に凸な関数を最適化(最小化)する。

通常の下に凸な二次関数は、微分をして0になる点が最小値となる。しかし、絶対値を含む関数は素直に微分ができないので、劣微分を用いる。劣微分に0が含まれる点を求めることで最適解が得られる。例えばf(x)=xf(x)=|x|の場合、x=0x=0のときに0f(x)=[1,1]0\in f'(x)=[-1,1]になるのでx=0x=0が最適解となる。

もう少し複雑な関数も考えてみよう。

f(x)=ax2+bx+c+dxa,d>0\begin{align} f(x) = ax^2 + b&x + c + d|x| \\ a, d &> 0 \end{align}

二次関数に絶対値を足した。これは下に凸の関数である。この関数を最小化するxxを求めよう。劣微分に0が含まれる点を求める。

まず微分する。

f(x)=2ax+b+d(x)\begin{align} f'(x) = 2ax + b + d(|x|)' \end{align}

ここで、

(x)={1(x<0)[1,1](x=0)1(x>0)\begin{align} (|x|)' = \begin{cases} -1 & (x<0) \\ [-1,1] & (x=0) \\ 1 & (x>0) \\ \end{cases} \end{align}

より、

f(x)={2ax+bd(x<0)[2ax+bd,2ax+b+d](x=0)2ax+b+d(x>0)\begin{align} f'(x) = \begin{cases} 2ax + b - d & (x<0) \\ [2ax + b - d, 2ax + b + d] & (x=0) \\ 2ax + b + d & (x>0) \\ \end{cases} \end{align}

あとはこれが0になる(0を含む)xxを求める。

まずx<0x<0の範囲を考える。f(x)=0f'(x)=0を解く。

f(x)=02ax+bd=0x=bd2a\begin{align} f'(x) &= 0 \\ 2ax + b - d &= 0 \\ x &= -\frac{b-d}{2a} \end{align}

そしてこのxx<0<0となるのはb>db>dの時なので、b>db>dの時、x=(bd)/2ax=-(b-d)/2af(x)f(x)の最適解となる。

同じようにx>0x>0の範囲を考える。

f(x)=02ax+b+d=0x=b+d2a\begin{align} f'(x) &= 0 \\ 2ax + b + d &= 0 \\ x &= -\frac{b+d}{2a} \end{align}

そしてこのxx>0>0となるのはb<db<-dの時なので、b<db<-dの時、x=(b+d)/2ax=-(b+d)/2af(x)f(x)の最適解となる。

最後に、どちらにも当てはまらない時、つまりdbd-d\leq b\leq dの時を考えるが、x<0x<0でもx>0x>0でもないことが確定しているので、x=0x=0f(x)f(x)の最適解となる。

まとめるとこうなる。

arg minxf(x)={b+d2a(b<d)0(dbd)bd2a(b>d)\begin{align} \argmin_x f(x) = \begin{cases} -\frac{b+d}{2a} & (b<-d) \\ 0 & (-d\leq b\leq d) \\ -\frac{b-d}{2a} & (b>d) \\ \end{cases} \end{align}

例を見てみよう。a,b,c,d=1,2,3,4a, b, c, d = 1, 2, 3, 4とすると、dbd-d\leq b\leq dになるのでx=0x=0が最適解となる。

Loading image...

またa,b,c,d=1,3,2,1a, b, c, d = 1, -3, 2, 1とすると、b<db<-dになるので

x=b+d2a=3+121=1\begin{align} x = -\frac{b+d}{2a}=-\frac{-3+1}{2\cdot1} = 1 \end{align}

が最適解となる。

Loading image...

ソフト閾値関数

Soft-Thresholding Function

以下の関数をソフト閾値関数や軟閾値作用素と呼ぶ。

S(x,λ)=sign(x)max(xλ,0)={x+λ(x<λ)0(λxλ)xλ(x>λ)λ0\begin{align} S(x, \lambda) &= \text{sign}(x) \max(|x|-\lambda, 0) \\ &= \begin{cases} x+\lambda & (x<-\lambda) \\ 0 & (-\lambda\leq x\leq \lambda) \\ x-\lambda & (x>\lambda) \\ \end{cases} \\ \lambda &\geq 0 \end{align}

Pythonで書くとこう。

def soft_thresholding(x, threshold):
    return np.sign(x) * np.maximum(np.abs(x) - threshold, 0)

λ=1\lambda=1とすると、こんなグラフ。

Loading image...

λ\lambdaを大きくすると0になる範囲が広がる。

この関数を用いると、先の解を以下のように表せる。

arg minxf(x)={b+d2a(b<d)0(dbd)bd2a(b>d)=S(b2a,d2a)\begin{align} \argmin_x f(x) &= \begin{cases} -\frac{b+d}{2a} & (b<-d) \\ 0 & (-d\leq b\leq d) \\ -\frac{b-d}{2a} & (b>d) \\ \end{cases} \\ &= S\left(-\frac{b}{2a}, \frac{d}{2a}\right) \end{align}

Lasso回帰の最適化

ではいよいよLasso回帰の最適化を行おう。目的関数は以下とする。

JLas(w)=12NyXw22+αw1\begin{align} J_\text{Las}(\bm w) = \frac{1}{2N}\| \bm y - X \bm w \|_2^2 + \alpha \| \bm w \|_1 \end{align}

scikit-learnに倣い、一項目をデータ数NNで割った。どうせα\alphaで比重を調整できるので定数を掛けても問題ない。NNで割っておくとα\alphaに関する議論でNNをあまり意識しなくてよくなる(データ数に応じてα\alphaを調整する必要がなくなる)。

💁

ちなみにさっきからずっと誤差項に1/21/2を掛けているのは微分したときにきれいになるからってだけ。あと一応後の章で出てくる正規分布との兼ね合いもある。分散1の正規分布を仮定すると1/21/2だけ残るのでそれと一致して分かりやすいよねって感じ。


座標降下法を用いて最適化を行う。ある一つのパラメータwiw_iに着目し、それ以外のパラメータwj(ij)w_j\,(i\neq j)を固定した状態でwiw_iを最小化する。

目的関数をwiw_iについて整理しよう。

JLas(wi)=12NyXw22+αw1=12Nn=1N(y(n)wx(n))2+αw1=12Nn=1N(y(n)w\ix\i(n)wixi(n))2+α(w\i1+wi)\begin{align} J_\text{Las}(w_i) &= \frac{1}{2N} \| \bm y - X \bm w \|_2^2 + \alpha \| \bm w \|_1 \\ &= \frac{1}{2N} \sum_{n=1}^N \left( y^{(n)} - \bm w^\top\bm x^{(n)} \right)^2 + \alpha \| \bm w \|_1 \\ &= \frac{1}{2N} \sum_{n=1}^N \left( y^{(n)} - \bm w_{\backslash i}^\top\bm x_{\backslash i}^{(n)} - w_ix_i^{(n)} \right)^2 + \alpha(\| \bm w_{\backslash i} \|_1 + |w_i|) \\ \end{align}

x\i\bm x_{\backslash i}x\bm xii番目の要素xix_i以外を並べたベクトル。今後のためにiiとそれ以外で分けた。

ここで以下を定義する。

r\i(n)=y(n)w\ix\i(n)r\i=(r\i(1),r\i(2),,r\i(N))RNxi=(xi(1),xi(2),,xi(N))RN\begin{align} r_{\backslash i}^{(n)} &= y^{(n)} - \bm w_{\backslash i}^\top\bm x_{\backslash i}^{(n)} \\ \bm r_{\backslash i} &= (r_{\backslash i}^{(1)}, r_{\backslash i}^{(2)}, \cdots, r_{\backslash i}^{(N)})^\top \in \R^N \\ \bm x_i &= (x_i^{(1)}, x_i^{(2)}, \cdots, x_i^{(N)})^\top \in \R^N \end{align}

r\i(n)r_{\backslash i}^{(n)}x(n)\bm x^{(n)}ii番目の特徴量xi(n)x_i^{(n)}を考慮せずに予測した値の誤差。r\i\bm r_{\backslash i}はそれらを並べたベクトル。xi\bm x_iii番目の特徴量を並べたベクトル。これで変形を進める。

=12Nn=1N(r\i(n)wixi(n))2+α(w\i1+wi)=12Nn=1N((r\i(n))22r\i(n)wixi(n)+(wixi(n))2)+α(w\i1+wi)=12Nn=1N(2r\i(n)wixi(n)+(wixi(n))2)+αwi+const\begin{align} &= \frac{1}{2N} \sum_{n=1}^N \left( r_{\backslash i}^{(n)} - w_ix_i^{(n)} \right)^2 + \alpha(\| \bm w_{\backslash i} \|_1 + |w_i|) \\ &= \frac{1}{2N}\sum_{n=1}^N \left( (r_{\backslash i}^{(n)})^2 - 2r_{\backslash i}^{(n)}w_ix_i^{(n)} + (w_ix_i^{(n)})^2 \right) + \alpha(\| \bm w_{\backslash i} \|_1 + |w_i|) \\ &= \frac{1}{2N}\sum_{n=1}^N \left( -2r_{\backslash i}^{(n)}w_ix_i^{(n)} + (w_ix_i^{(n)})^2 \right) + \alpha |w_i| + \text{const} \\ \end{align}

wiw_iに関係のない項を定数としてconst\text{const}にまとめた。さらに変形を進めて

=12N(2win=1Nr\i(n)xi(n)+wi2n=1N(xi(n))2)+αwi+const=12N(2wir\ixi+wi2xi22)+αwi+const=1Nwir\ixi+12Nwi2xi22+αwi+const\begin{align} &= \frac{1}{2N} \left( -2w_i \sum_{n=1}^N r_{\backslash i}^{(n)}x_i^{(n)} + w_i^2 \sum_{n=1}^N (x_i^{(n)})^2 \right) + \alpha |w_i| + \text{const} \\ &= \frac{1}{2N} \left( -2w_i\bm r_{\backslash i}^\top\bm x_i + w_i^2\|\bm x_i\|_2^2 \right) + \alpha |w_i| + \text{const} \\ &= -\frac{1}{N}w_i\bm r_{\backslash i}^\top\bm x_i + \frac{1}{2N}w_i^2\|\bm x_i\|_2^2 + \alpha |w_i| + \text{const} \\ \end{align}

wi2w_i^2について整理すると

JLas(wi)=12Nxi22wi21Nr\ixiwi+const+αwi\begin{align} J_\text{Las}(w_i) = \frac{1}{2N} \| \bm x_i \|_2^2w_i^2 - \frac{1}{N}\bm r_{\backslash i}^\top\bm x_iw_i + \text{const} + \alpha|w_i| \end{align}

前章で解いたf(x)=ax2+bx+c+dx(a,d>0)f(x)=ax^2+bx+c+d|x|\quad(a, d > 0)と同じ形になった。この解はS(b2a,d2a)S\left(-\frac{b}{2a}, \frac{d}{2a}\right)であるため、

a=12Nxi22b=1Nr\ixid=α\begin{align} a &= \frac{1}{2N} \| \bm x_i \|_2^2 \\ b &= -\frac{1}{N}\bm r_{\backslash i}^\top\bm x_i \\ d &= \alpha \\ \end{align}

を代入すると

arg minwiJLas(wi)=S(r\ixixi22,Nαxi22)\begin{align} \argmin_{w_i} J_\text{Las}(w_i) = S \left( \frac{\bm r_{\backslash i}^\top\bm x_i}{\|\bm x_i\|_2^2}, \frac{N\alpha}{\|\bm x_i\|_2^2} \right) \end{align}

が得られる。これが各ステップにおけるwiw_iの更新先となる。

また、バイアス項に対応するパラメータは正則化をしないので、解が変わる。w0w_0をバイアスとし、正則化項を無視すると、

JLas(w0)=12Nx022w021Nr\0x0w0+const\begin{align} J_\text{Las}(w_0) = \frac{1}{2N} \|\bm x_0 \|_2^2w_0^2 - \frac{1}{N}\bm r_{\backslash 0}^\top\bm x_0w_0 + \text{const} \end{align}

となるので、これを微分し、

JLasw0=1Nx022w01Nr\0x0\begin{align} \frac{\partial J_\text{Las}}{\partial w_0} = \frac{1}{N}\|\bm x_0\|_2^2w_0 - \frac{1}{N}\bm r_{\backslash 0}^\top\bm x_0 \end{align}

=0=0とおいてw0w_0について解く。

w0=r\0x0x022\begin{align} w_0 = \frac{\bm r_{\backslash 0}^\top\bm x_0}{\|\bm x_0\|_2^2} \end{align}

w0w_0をバイアスとする場合、x0=(1,1,,1)RN\bm x_0=(1, 1, \cdots, 1)^\top\in\R^Nとなるので、

w0=n=1Nr\0(n)N=r\0ˉ\begin{align} w_0 &= \frac{\sum_{n=1}^Nr_{\backslash 0}^{(n)}}{N} \\ &= \bar{r_{\backslash 0}} \\ \end{align}

となる。バイアス項を無視した予測値と正解の差の平均だね。

以上で全てのパラメータの更新先が得られた。最後にLasso回帰の最適化手順をまとめる。

  1. w\bm wを初期化
  2. w0w_0を最適化: w0r\0ˉw_0 \leftarrow \bar{r_{\backslash 0}}
  3. w1,w2,,wmw_1, w_2, \cdots, w_mを最適化: wiS(r\ixixi22,Nαxi22)w_i \leftarrow S \left( \frac{\bm r_{\backslash i}^\top\bm x_i}{\|\bm x_i\|_2^2}, \frac{N\alpha}{\|\bm x_i\|_2^2} \right)
  4. 収束するまで2, 3を繰り返す

Python実装

実装してみよう。

import numpy as np
 
class Lasso:
    def __init__(self, alpha: float = 1., max_iter: int = 1000):
        self.alpha = alpha
        self.max_iter = max_iter # 繰り返しの上限
        self.weights = None
 
    @staticmethod
    def soft_thresholding(x, threshold):
        return np.sign(x) * np.maximum(np.abs(x) - threshold, 0)
 
    def fit(self, X, y):
        X = np.insert(X, 0, 1, axis=1)
        n_samples, n_features = X.shape
        self.weights = np.zeros(n_features) # パラメータの初期化
        for _ in range(self.max_iter): # 上限まで繰り返し
            # w0の更新
            self.weights[0] = np.mean(y - X[:, 1:] @ self.weights[1:])
 
            # w1, w2, ..., wmの更新
            for i in range(1, n_features):
                r_i = y - np.delete(X, i, 1) @ np.delete(self.weights, i)
                x_i = X[:, i]
                norm = np.linalg.norm(x_i) ** 2
                self.weights[i] = self.soft_thresholding(
                    r_i @ x_i / norm, n_samples * self.alpha / norm
                )
 
    def predict(self, X):
        X = np.insert(X, 0, 1, axis=1)
        return X @ self.weights

収束するまでと書いたが、コードをシンプルにするため、繰り返しの上限max_iterを導入し、上限に達したら終了するようにした。

適当なデータで学習させてみる。

from sklearn.datasets import load_diabetes
X, y = load_diabetes(return_X_y=True)
 
model = Lasso(alpha=1.)
model.fit(X, y)
print("\n".join(model.weights.astype(str).tolist()))
152.133484162896
0.0
-0.0
367.70162582143126
6.30970264417499
0.0
0.0
-0.0
0.0
307.60214746219583
0.0

sklearnとも一致する。

from sklearn.linear_model import Lasso as LassoSklearn
 
model = LassoSklearn(alpha=1., tol=0., max_iter=1000)
model.fit(X, y)
weights = np.append([model.intercept_], model.coef_)
print("\n".join(weights.astype(str).tolist()))
152.133484162896
0.0
-0.0
367.70162582143115
6.309702644174996
0.0
0.0
-0.0
0.0
307.6021474621959
0.0

微小な誤差が生じているが、細かい実装方法が違うだけかな。

Lasso回帰の性質

最適化から離れた、ちょっとしたLasso回帰の小話

スパース推定

前章で得られた解を見てみると、多くの要素が0になっていることが分かる。このような多くの要素が0であるベクトルはスパースなベクトルや疎ベクトルと呼ばれる。反対に多くの要素が0でないベクトルは密ベクトルと呼ばれる。

こういったスパースな解を得ることはスパース推定やスパースモデリングと呼ばれる。スパース推定によって多くのデータの中から重要なデータのみを取り出すことができる。Lasso回帰では、複数の特徴量の中から目的変数の予測に寄与するものだけを取り出すことができる。Lasso回帰は特徴量選択という面で非常に解釈性の高いモデルと言える。

なぜLasso回帰によってスパースな解が得られるかはパラメータの更新式を見ると分かりやすい。

wiS(r\ixixi22,Nαxi22)\begin{align} w_i \leftarrow S \left( \frac{\bm r_{\backslash i}^\top\bm x_i}{\|\bm x_i\|_2^2}, \frac{N\alpha}{\|\bm x_i\|_2^2} \right) \end{align}

ソフト閾値関数S(x,λ)S(x, \lambda)xxの絶対値が閾値λ\lambdaを超えない場合に0となる関数である。α\alphaを大きくすると閾値が大きくなり、結果、wiw_iが0になることが多くなる。

ちなみに、L1ではなくL0ノルムによる正則化を行う手法もある。L0ノルムはベクトルの非ゼロの要素の数を表すため、より厳密なスパース推定が可能になる。ただ計算量的に最適化が困難であるため、多くの場合はL1ノルムによる正則化が用いられる。

ベイズ的解釈

線形回帰の最小二乗法は誤差に正規分布を仮定した最尤推定と解釈できる。実際にその仮定で対数尤度の計算を進めると

p(yx;w)=N(y;wx,1)lnp(yX;w)=nlnp(y(n)x(n);w)=nln(12πexp((y(n)wx(n))22))=12n(y(n)wx(n))2+const=12yXw22+const\begin{align} p(y|\bm x; \bm w) &= \mathcal N(y; \bm w^\top\bm x,1) \\ \ln p(\bm y| X; \bm w) &= \sum_n \ln p(y^{(n)}| \bm x^{(n)}; \bm w) \\ &= \sum_n \ln \left( \frac{1}{\sqrt{2\pi}} \exp \left( -\frac{(y^{(n)} - \bm w^\top \bm x^{(n)})^2}{2} \right) \right) \\ &= -\frac{1}{2} \sum_n (y^{(n)} - \bm w^\top \bm x^{(n)})^2 + \text{const} \\ &= -\frac{1}{2} \| \bm y - X\bm w \|_2^2 + \text{const} \end{align}

負の二乗和誤差が出てくる。この最大化は二乗和誤差の最小化と同じである。ではここに正則化項を足すことは何を意味するか。

ここに正則化項を足すことは、確率モデルに事前分布p(w)p(\bm w)を仮定することと同じ意味になる。例えばL2正則化はパラメータの事前分布として平均0の正規分布を仮定することと解釈できる。

p(w)=N(w;0,1λI)lnp(wX,y)=lnp(yX;w)+lnp(w)=12yXw2212λw22+const\begin{align} p(\bm w) &= \mathcal N(\bm w; \bm 0, \frac{1}{\lambda} I) \\ \ln p(\bm w| X, \bm y) &= \ln p(\bm y| X; \bm w) + \ln p(\bm w) \\ &= -\frac{1}{2} \| \bm y - X\bm w \|_2^2 - \frac{1}{2}\lambda \| \bm w\|_2^2 + \text{const} \end{align}

Lasso回帰はL1正則化を行うが、これはパラメータの事前分布としてラプラス分布を仮定していると解釈できる。

p(w)=Lap(w;0,λ)lnp(wX,y)=lnp(yX;w)+lnp(w)=12yXw22λw1+const\begin{align} p(\bm w) &= \text{Lap}(\bm w; \bm 0, \lambda) \\ \ln p(\bm w| X, \bm y) &= \ln p(\bm y| X; \bm w) + \ln p(\bm w) \\ &= -\frac{1}{2} \| \bm y - X\bm w \|_2^2 - \lambda \| \bm w\|_1 + \text{const} \\ \end{align}

このあたりの細かい話は別の記事に書いたので、興味があればぜひ: 線形回帰における最尤推定・MAP推定・ベイズ推定

ElasticNet

Ridge回帰とLasso回帰を組み合わせたモデル。L1正則化とL2正則化を両方行う。次の目的関数を最適化する。

JEN(w)=12NyXw22+α(βw1+(1β)12w22)\begin{align} J_\text{EN}(\bm w) = \frac{1}{2N} \| \bm y - \bm X \bm w \|_2^2 + \alpha \left( \beta \| \bm w \|_1 + (1-\beta) \frac{1}{2} \| \bm w \|_2^2 \right) \end{align}

β\betaはL1正則化の割合を表すハイパーパラメータ。β=1\beta=1とするとLasso回帰、β=0\beta=0とするとRidge回帰と一致する。

解き方はLassoと一緒。絶対値が含まれているので解析的には解けない。ここでも座標降下法を使う。Lasso同様、ある一つのパラメータwiw_iに着目した目的関数J(wi)J(w_i)が必要。上の式を変形していっても良いが、面倒くさいので、Lasso回帰とどこが変わったかを考える。

Lasso回帰の目的関数は以下であった。

JLas(wi)=12Nxi22wi21Nr\ixiwi+const+αwi\begin{align} J_\text{Las}(w_i) = \frac{1}{2N} \| \bm x_i \|_2^2w_i^2 - \frac{1}{N} \bm r_{\backslash i}^\top\bm x_iw_i + \text{const} + \alpha|w_i| \end{align}

各項の係数がどう変化するかを考えよう。L2正則化は各パラメータwiw_iの二乗を足すので、増えた分を二乗の項の係数に足せば良い。つまりハイパーパラメータと定数を掛けたα(1β)/2\alpha(1-\beta)/2を足せば良い。またL1ノルムには新たなパラメータβ\betaが追加で掛けられる。これらをまとめるとこうなる。

JEN(wi)=12N(xi22+Nα(1β))wi21Nr\ixiwi+const+αβwi\begin{align} J_\text{EN}(w_i) = \frac{1}{2N}(\|\bm x_i\|_2^2 + N\alpha(1-\beta))w_i^2 - \frac{1}{N}\bm r_{\backslash i}^\top\bm x_iw_i + \text{const} + \alpha\beta|w_i| \end{align}

ここからの解き方は同じ。

a=12N(xi22+Nα(1β))b=1Nr\ixid=αβ\begin{align} a &= \frac{1}{2N}\big(\|\bm x_i\|_2^2 + N\alpha(1-\beta)\big) \\ b &= -\frac{1}{N}\bm r_{\backslash i}^\top\bm x_i \\ d &= \alpha\beta \end{align}

とすると、解はS(b2a,d2a)S\left(-\frac{b}{2a}, \frac{d}{2a}\right)なので

arg minwiJEN(wi)=S(r\ixixi22+Nα(1β),Nαβxi22+Nα(1β))\begin{align} \argmin_{w_i} J_\text{EN}(w_i) = S \left( \frac{\bm r_{\backslash i}^\top\bm x_i}{\|\bm x_i\|_2^2 + N\alpha(1-\beta)}, \frac{N\alpha\beta}{\|\bm x_i\|_2^2 + N\alpha(1-\beta)} \right) \end{align}

が得られる。

またバイアスについては正則化項を考慮しないため先と同じく

w0=r\0ˉ\begin{align} w_0 = \bar{r_{\backslash 0}} \end{align}

となる。

実装してみよう。

class ElasticNet:
    def __init__(
        self,
        alpha: float = 1.,
        beta: float = 0.5,
        max_iter: int = 1000
    ):
        self.alpha = alpha
        self.beta = beta
        self.max_iter = max_iter
        self.weights = None
 
    @staticmethod
    def soft_thresholding(x, threshold):
        return np.sign(x) * np.maximum(np.abs(x) - threshold, 0)
 
    def fit(self, X, y):
        X = np.insert(X, 0, 1, axis=1)
        n_samples, n_features = X.shape
        self.weights = np.zeros(n_features)
        for _ in range(self.max_iter):
            self.weights[0] = np.mean(y - X[:, 1:] @ self.weights[1:])
            for i in range(1, n_features):
                r_i = y - np.delete(X, i, 1) @ np.delete(self.weights, i)
                x_i = X[:, i]
                norm = np.linalg.norm(x_i) ** 2
                norm += n_samples * self.alpha * (1 - self.beta) # 追加
                self.weights[i] = self.soft_thresholding(
                    r_i @ x_i / norm,
                    n_samples * self.alpha * self.beta / norm # 変更
                )
 
    def predict(self, X):
        X = np.insert(X, 0, 1, axis=1)
        return X @ self.weights

更新式が少しだけ変わった。分母はノルムだけではないからnormという変数名を変えるか迷ったけど、normには規格という意味もあるみたいなのでそのままにした。

model = ElasticNet(alpha=1., beta=0.5)
model.fit(X, y)
print("\n".join(model.weights.astype(str).tolist()))
152.13348416289594
0.3590175634148627
0.0
3.259766998005527
2.2043402383839803
0.5286453997828984
0.2509350904357106
-1.8613631921210814
2.1144540777001035
3.105834685472744
1.7698510183435376

sklearnとも一致する。

from sklearn.linear_model import ElasticNet as ElasticNetSklearn
 
model = ElasticNetSklearn(alpha=1., l1_ratio=0.5, tol=0., max_iter=1000)
model.fit(X, y)
weights = np.append([model.intercept_], model.coef_)
print("\n".join(weights.astype(str).tolist()))
152.13348416289594
0.3590175634148638
0.0
3.2597669980055266
2.204340238383981
0.5286453997828972
0.2509350904357103
-1.8613631921210825
2.1144540777001053
3.105834685472747
1.7698510183435392

オワリ

おつ。