差分

このページの2つのバージョン間の差分を表示します。

この比較画面へのリンク

両方とも前のリビジョン 前のリビジョン
次のリビジョン
前のリビジョン
intro:researches:machine [2023/05/29 11:21] – [ステップ数$N$とバッチサイズ$b$の関係] Naoki SATOintro:researches:machine [2023/06/02 13:40] (現在) – [収束解析] Naoki SATO
行 103: 行 103:
 ( \sigma_D^2 + M_D^2 )}}_{C_D} ( \sigma_D^2 + M_D^2 )}}_{C_D}
 \end{align} \end{align}
-ただし、$\alpha^G, \alpha^D$はoptimizerの学習率で、Nはステップ数、bはバッチサイズです。その他の定義については論文を参照してください。このことから、ステップ数$N$とバッチサイズ$b$を大きく、学習率$\alpha^G, \alpha^D$は小さくすれば、それぞれの右辺は0に近くなり、局所的ナッシュ均衡を近似できることがわかります。+ただし、$\alpha^G, \alpha^D$はoptimizerの学習率で、Nはステップ数、bはバッチサイズです。その他の定義については[[https://arxiv.org/pdf/2201.11989.pdf|論文]]を参照してください。このことから、ステップ数$N$とバッチサイズ$b$を大きく、学習率$\alpha^G, \alpha^D$は小さくすれば、それぞれの右辺は0に近くなり、局所的ナッシュ均衡を近似できることがわかります。
  
 ==== ステップ数$N$とバッチサイズ$b$の関係 ==== ==== ステップ数$N$とバッチサイズ$b$の関係 ====
行 129: 行 129:
 0 0
 \end{align} \end{align}
-が成り立ちます。1階微分が$0$以下で、2階微分が$0$以上なので、$N_G(b)$は単調減少で、かつ凸関数であることがわかります。$N_D(b)$についても同様です。+が成り立ちます。1階微分が$0$以下で、2階微分が$0$以上なので、$N_G(b)$は$b$に対して単調減少で、かつ凸関数であることがわかります。$N_D(b)$についても同様です。\\
 このことから、学習が収束するまでに必要なステップ数$N$を最小にするバッチサイズ$b$が存在することが分かります。 このことから、学習が収束するまでに必要なステップ数$N$を最小にするバッチサイズ$b$が存在することが分かります。
 ==== クリティカルバッチサイズの存在 ==== ==== クリティカルバッチサイズの存在 ====
行 136: 行 136:
 そこで、GANの訓練にもクリティカルバッチサイズが存在するかどうかを考えてみましょう。 そこで、GANの訓練にもクリティカルバッチサイズが存在するかどうかを考えてみましょう。
  
 +学習の計算量の指標である確率的勾配計算コスト(SFO計算量:stochastic first-order oracle complexity)は、それぞれ$N_G(b)b,N_D(b)b$で定義できます。これは、$1$回の反復で$b$個の確率的勾配を計算するからです。
 +\begin{align}
 +N_G(b)b
 +=
 +\frac{A_G b^2}{(\epsilon_G^2 - C_G)b - B_G}, \text{  }
 +N_D(b)b
 +=
 +\frac{A_D b^2}{(\epsilon_D^2 - C_D)b - B_D}
 +\end{align}
 +先ほどと同様に微分して形状を調べてみましょう。
 +\begin{align}
 +\frac{\mathrm{d}N_G(b)b}{\mathrm{d}b}
 +=
 +\frac{A_G b\{(\epsilon_G^2-C_G)b -2B_G\}}{\{(\epsilon_G^2-C_G)b-B_G\}^2}
 +, \text{  }
 +\frac{\mathrm{d}^2 N_G(b)b}{\mathrm{d}b^2}
 +=
 +\frac{2A_G B_G^2}{\{(\epsilon_G^2-C_G)b-B_G\}^3}
 +\geq
 +0
 +\end{align}
 +が成り立ちます。2階微分が常に$0$以上なので、$N_G(b)b$は$b$に対して凸関数であることが分かります。$N_D(b)b$についても同様です。
 +また、1階微分から、$N_G(b)b, N_D(b)b$を最小にする$b_G$と$b_D$は、次のように書けます。
 +\begin{align}
 +b_G^\star := \frac{2B_G}{\epsilon_G^2-C_G},
 +\text{  }
 +b_D^\star := \frac{2B_D}{\epsilon_D^2-C_D}
 +\end{align}
 +この$b_G^\star$と$b_D^\star$がクリティカルバッチサイズです。
 ==== クリティカルバッチサイズの推定 ==== ==== クリティカルバッチサイズの推定 ====
 +ここまでで、GANの訓練にもクリティカルバッチサイズが存在することがわかりました。このクリティカルバッチサイズを、事前に知ることを目指します。\\
 +GANの目標は綺麗な生成画像を得ることです。学習の停止条件にも、生成画像の品質が利用されます。なので、ここでは生成器のクリティカルバッチサイズについて考えてみましょう。
 +生成器$G$のクリティカルバッチサイズは、次のように表すことができました。
 +\begin{align}
 +b_G^\star := \frac{2B_G}{\epsilon_G^2 -C_G}
 +\end{align}
 +$B_G$や$C_G$の定義に立ち返って式変形をすると、$b_G^\star$の下界をoptimizerごとに次のように表すことができます。
 +  * Adam
 +\begin{align}
 +b_G^\star \geq 
 +\frac{\sigma_G^2}{\epsilon_G^3}\frac{\alpha^G}{(1-\beta_1^G)^3 \sqrt{\frac{\Theta}{1-\beta_2^G} \frac{1}{|S|^2}}}
 +\end{align}
 +  * AdaBelief
 +\begin{align}
 +b_G^\star \geq 
 +\frac{\sigma_G^2}{\epsilon_G^3}\frac{\alpha^G}{(1-\beta_1^G)^3 \sqrt{\frac{4\Theta}{1-\beta_2^G} \frac{1}{|S|^2}}}
 +\end{align}
 +  * RMSProp
 +\begin{align}
 +b_G^\star \geq 
 +\frac{\sigma_G^2}{\epsilon_G^3}\frac{\alpha^G}{\sqrt{\frac{\Theta}{|S|^2}}}
 +\end{align}
 +このとき、$\sigma_G^2 / \epsilon_G^3$のみ未知です。$\alpha^G, \beta_1^G, \beta_2^G$はoptimizerのハイパーパラメータなので、ユーザーが自由に設定できます。GANの訓練では、$\alpha^G=0.0001, \beta_1^G=0.5, \beta_2^G=0.999$などとするのが一般的です。
 +$\Theta$は生成器の次元でした。例えばDCGAN architectureならば、$\Theta=3,576,704$です。$|S|$はデータセットのデータの総数です。例えばLSUN Bedroomデータセットならば、$|S|=3,033,042$です。
 +
 +それでは、未知数である$\sigma_G^2 / \epsilon_G^3$について考えてみましょう。これは決して事前には分からないので、クリティカルバッチサイズの測定値と、先ほどの下界の推定式を利用して逆算します。いくつかの準備が必要なので、順番に見ていきましょう。\\
 +  1. $Nb-b$グラフ\\
 +DCGANを訓練して、LSUN Bedroomデータセットの実画像と瓜二つの生成画像を作ります。その学習のクリティカルバッチサイズの測定値を求めましょう。
 +クリティカルバッチサイズは、SFO計算量である$Nb$を最小にする$b$だったので、その測定値を知るためには、$Nb-b$グラフ(縦軸が$Nb$で、横軸が$b$のグラフ)を用意する必要があります。ここで、$N$は学習が収束するまでに必要なステップ数です。$b$はバッチサイズで、機械学習では慣例的に$2$の累乗を利用することが一般的です。理論的には$Nb$は$b$に対して凸関数であるはずなので、例えば次のようなグラフが得られるはずです。
 +{{ :intro:researches:nb-b.png?400 |}}
 +\\
 +  2. FID\\
 +一般的に、GANの学習の停止条件にはFID(Fr\'echet Inception Distance)と呼ばれる指標が利用されます。FIDは2つの画像(※実際には2つの正規分布)の離れ具合を測る指標で、その値が低ければ低いほど2つの画像が似ていることを意味します。全く同じ画像同士のFIDは$0$となります。GANの学習が収束するとき、生成器は綺麗な生成画像を出力できるようになっているはずですから、その生成画像と実画像のFIDは十分に低いはずです。
 +ここまでのことを踏まえて、『学習が収束するまでに必要なステップ数$N$』を、『十分に低いFIDを達成するまでに必要なステップ数$N$』と読み替えます。
 +LSUN Bedroomデータセットを使って、DCGANを訓練する場合、FIDの極限は41.8程度です。学習の不安定性を考慮して、今回の実験では、FID=70を十分に低いFIDだとして、FID=70を停止条件にします。\\
 +
 +{{ :intro:researches:fid.png?800 |}}
 +\\
 +  3. 学習率\\
 +学習率は学習の結果に大きく影響するので、事前に適切な学習率を探索することが極めて重要です。GANの学習は非常に不安定です。DCGANは標準的なGANなので、特に学習率などのハイパーパラメータにとても敏感に影響を受けます。GANには生成器と識別器があるので、ある一定のステップ数の学習で、最も低いFIDを達成することができる学習率の組み合わせ$(\alpha^D, \alpha^G)$を探します。DCGANの場合は、次のような結果になります。
 +
 +{{ :intro:researches:grid.png?1200 |}}
 +
 +バッチサイズは$64$で固定し、Adamは$60000$steps、AdaBeliefは$120000$steps、RMSPropは$180000$stepsの結果です。青色が濃いほど低いFIDを達成できたことを意味します。
 +これによると、最も良い学習率の組み合わせは、optimizerごとに次のようになります。
 +
 +{{ :intro:researches:hyper.png?800 |}}
 +
 +なお、$\beta_1$と$\beta_2$は一般的な値を使用します。\\
 +\\
 +それでは、いよいよバッチサイズを変えて、『FID=70を達成するまでに必要なステップ数』を計測していきましょう。
 +バッチサイズは大きければ大きいほどGPUをたくさん使用するので、計算機の都合上、今回の実験では、$b=[4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]$の範囲で実験しました。
 +$Nb-b$グラフを作ると、次のようになります。
 +
 +{{ :intro:researches:nbb01.png?600 |}}
 +
 +予想通り、全てのoptimizerで凸関数になっています。このグラフの最小値をとるような$b$が、それぞれのクリティカルバッチサイズです。これで、Adamのクリティカルバッチサイズの測定値は$2^5=32$であることが分かりました。これを使って、未知数$\sigma_G^2 / \epsilon_G^3$を逆算しましょう。Adamの推定式をもう一度掲載します。
 +\begin{align}
 +b_G^\star \geq 
 +\frac{\sigma_G^2}{\epsilon_G^3}\frac{\alpha^G}{(1-\beta_1^G)^3 \sqrt{\frac{\Theta}{1-\beta_2^G} \frac{1}{|S|^2}}}
 +\end{align}
 +これを未知数$\sigma_G^2 / \epsilon_G^3$について書き直して、既知のパラメータに使用した値を代入すると、
 +\begin{align}
 +\frac{\sigma_G^2}{\epsilon_G^3} \leq 788.7
 +\end{align}
 +とできます。これを使えば、AdaBeliefとRMSPropの推定値を計算できます。推定値は四角で、測定値は丸でマーキングしてあります。
 +さらに、生成器のモデルにDCGAN architectureを採用している場合は、この比$\sigma_G^2 / \epsilon_G^3$を適用できるので、別のGANでDCGAN architectureを採用している場合にもこの推定式は有効です。実際、WGAN-GPでCelebAデータセットを訓練する場合にも、推定値と測定値は近くなります。
 +
 +{{ :intro:researches:estimated.png?600 |}}
 +
 +DCGANがSection4.1で、WGAN-GPがSection4.2に当たります。RMSProp以外では完全に推定に成功していることが分かります。RMSPropで推定が上手くいかない原因は、RMSPropのクリティカルバッチサイズの推定式に$\beta_1$と$\beta_2$が含まれていないことであると考えられます。
  
 Naoki Sato, Hideaki Iiduka: Existence and Estimation of Critical Batch Size for Training Generative Adversarial Networks with Two Time-Scale Update Rule, Proceedings of The 40th International Conference on Machine Learning, PMLR 202: ??–?? (2023) Naoki Sato, Hideaki Iiduka: Existence and Estimation of Critical Batch Size for Training Generative Adversarial Networks with Two Time-Scale Update Rule, Proceedings of The 40th International Conference on Machine Learning, PMLR 202: ??–?? (2023)
  • intro/researches/machine.1685326897.txt.gz
  • 最終更新: 2023/05/29 11:21
  • by Naoki SATO