r0w0

PythonやDeepLearning関連で学んだこと、調べたことの備忘録

PyTorchのTensor.backwardとoptimizer.updateの関係

状況

長きに渡るデータ前処理の日々を過ごしていたらそれ以外の基本的なことを思い出せなくなっていたでござる。First In First Out.

タイトルの通り、PyTorchのTensor.backwardとoptimizer.updateって何をしているのか、どういった関係にあるのかを以下に簡単にまとめる。

パラメータとみなすテンソルの指定

この過去記事で触れたように、テンソル微分対象となる変数かそうで無いかのフラグを持つ。フラグの立っているテンソルはパラメータと呼ばれる。通常、重みテンソルがパラメータとして指定される。

nn.Linearで生成された重みテンソルは標準でこのフラグが立っている。

勾配計算対象テンソルの指定

optimizerの定義時に、何のテンソルについて勾配を用いて更新するのか指定する。

基本的にはパラメータ、つまり重みテンソルを指定する。

optim.SGD(model.parameters(), lr=0.01, momentum=0.9) は、model.parameters()の返すテンソルたちを更新するよと宣言している。

勾配の初期化

テンソルの更新に用いる勾配は、テンソルの .grad に保持されている。その値を0に初期化する。

optimizer.zero_grad()の実行により達成される。

後述する.backwardの動作として、既存の.gradへの上書きではなく加算。従って、意図的に加算したい場合以外は学習イテレーションの都度、.zero_grad()を実行する必要がある。

勾配計算

計算処理の結果の値(基本的にはLoss)のテンソルに対して .backward() を実行する。これにより、入力から出力までの計算グラフにおいて、パラメータを変数とした微分計算が行われる。微分結果、即ち勾配はパラメータのテンソルに保持される(Tensor.gradに保持)。

勾配計算結果をパラメータに反映

opzimizer.step() は「1.」で指定されたテンソルについて、テンソルの.gradの値を用いてテンソル値の更新を行う

参考

machine learning - pytorch - connection between loss.backward() and optimizer.step() - Stack Overflow