差分
このページの2つのバージョン間の差分を表示します。
両方とも前のリビジョン 前のリビジョン 次のリビジョン | 前のリビジョン 次のリビジョン両方とも次のリビジョン | ||
intro:researches:machine [2023/05/29 11:48] – [クリティカルバッチサイズの存在] Naoki SATO | intro:researches:machine [2023/05/29 17:30] – [クリティカルバッチサイズの推定] Naoki SATO | ||
---|---|---|---|
行 166: | 行 166: | ||
この$b_G^\star$と$b_D^\star$がクリティカルバッチサイズです。 | この$b_G^\star$と$b_D^\star$がクリティカルバッチサイズです。 | ||
==== クリティカルバッチサイズの推定 ==== | ==== クリティカルバッチサイズの推定 ==== | ||
+ | ここまでで、GANの訓練にもクリティカルバッチサイズが存在することがわかりました。このクリティカルバッチサイズを、事前に知ることを目指します。\\ | ||
+ | GANの目標は綺麗な生成画像を得ることです。学習の停止条件にも、生成画像の品質が利用されます。なので、ここでは生成器のクリティカルバッチサイズについて考えてみましょう。 | ||
+ | 生成器$G$のクリティカルバッチサイズは、次のように表すことができました。 | ||
+ | \begin{align} | ||
+ | b_G^\star := \frac{2B_G}{\epsilon_G^2 -C_G} | ||
+ | \end{align} | ||
+ | $B_G$や$C_G$の定義に立ち返って式変形をすると、$b_G^\star$の下界をoptimizerごとに次のように表すことができます。 | ||
+ | * Adam | ||
+ | \begin{align} | ||
+ | b_G^\star \geq | ||
+ | \frac{\sigma_G^2}{\epsilon_G^3}\frac{\alpha^G}{(1-\beta_1^G)^3 \sqrt{\frac{\Theta}{1-\beta_2^G} \frac{1}{|S|^2}}} | ||
+ | \end{align} | ||
+ | * AdaBelief | ||
+ | \begin{align} | ||
+ | b_G^\star \geq | ||
+ | \frac{\sigma_G^2}{\epsilon_G^3}\frac{\alpha^G}{(1-\beta_1^G)^3 \sqrt{\frac{4\Theta}{1-\beta_2^G} \frac{1}{|S|^2}}} | ||
+ | \end{align} | ||
+ | * RMSProp | ||
+ | \begin{align} | ||
+ | b_G^\star \geq | ||
+ | \frac{\sigma_G^2}{\epsilon_G^3}\frac{\alpha^G}{\sqrt{\frac{\Theta}{|S|^2}}} | ||
+ | \end{align} | ||
+ | このとき、$\sigma_G^2 / \epsilon_G^3$のみ未知です。$\alpha^G, | ||
+ | $\Theta$は生成器の次元でした。例えばDCGAN architectureならば、$\Theta=3, | ||
+ | |||
+ | それでは、未知数である$\sigma_G^2 / \epsilon_G^3$について考えてみましょう。これは決して事前には分からないので、クリティカルバッチサイズの測定値と、先ほどの下界の推定式を利用して逆算します。いくつかの準備が必要なので、順番に見ていきましょう。\\ | ||
+ | 1. $Nb-b$グラフ\\ | ||
+ | DCGANを訓練して、LSUN Bedroomデータセットの実画像と瓜二つの生成画像を作ります。その学習のクリティカルバッチサイズの測定値を求めましょう。 | ||
+ | クリティカルバッチサイズは、SFO計算量である$Nb$を最小にする$b$だったので、その測定値を知るためには、$Nb-b$グラフ(縦軸が$Nb$で、横軸が$b$のグラフ)を用意する必要があります。ここで、$N$は学習が収束するまでに必要なステップ数です。$b$はバッチサイズで、機械学習では慣例的に$2$の累乗を利用することが一般的です。理論的には$Nb$は$b$に対して凸関数であるはずなので、例えば次のようなグラフが得られるはずです。 | ||
+ | {{ : | ||
+ | \\ | ||
+ | 2. FID\\ | ||
+ | 一般的に、GANの学習の停止条件にはFID(Fr\' | ||
+ | ここまでのことを踏まえて、『学習が収束するまでに必要なステップ数$N$』を、『十分に低いFIDを達成するまでに必要なステップ数$N$』と読み替えます。 | ||
+ | LSUN Bedroomデータセットを使って、DCGANを訓練する場合、FIDの極限は41.8程度です。学習の不安定性を考慮して、今回の実験では、FID=70を十分に低いFIDだとして、FID=70を停止条件にします。\\ | ||
+ | \\ | ||
+ | 3. 学習率\\ | ||
+ | 学習率は学習の結果に大きく影響するので、事前に適切な学習率を探索することが極めて重要です。GANの学習は非常に不安定です。DCGANは標準的なGANなので、特に学習率などのハイパーパラメータにとても敏感に影響を受けます。GANには生成器と識別器があるので、ある一定のステップ数の学習で、最も低いFIDを達成することができる学習率の組み合わせ$(\alpha^D, | ||
+ | |||
+ | {{ : | ||
+ | |||
+ | バッチサイズは$64$で固定し、Adamは$60000$steps、AdaBeliefは$120000$steps、RMSPropは$180000$stepsの結果です。青色が濃いほど低いFIDを達成できたことを意味します。 | ||
+ | これによると、最も良い学習率の組み合わせは、optimizerごとに次のようになります。 | ||
+ | |||
+ | {{ : | ||
+ | |||
+ | |||
Naoki Sato, Hideaki Iiduka: Existence and Estimation of Critical Batch Size for Training Generative Adversarial Networks with Two Time-Scale Update Rule, Proceedings of The 40th International Conference on Machine Learning, PMLR 202: ??–?? (2023) | Naoki Sato, Hideaki Iiduka: Existence and Estimation of Critical Batch Size for Training Generative Adversarial Networks with Two Time-Scale Update Rule, Proceedings of The 40th International Conference on Machine Learning, PMLR 202: ??–?? (2023) |