nikkie-ftnextの日記

イベントレポートや読書メモを発信

PyTorchのモデル(nn.Module)には、これから訓練に入ることをtrain()メソッドで教えてあげよう

はじめに

ぴえぴえ...🤯 nikkieです。

Today I learnedです。
開発者がtorch.nn.Moduleお世話をするんだ!

目次

私にはTrainerの中がブラックボックス

BERT系のモデルやLLMをファインチューン1していますが、ぶっちゃけtransformers.Trainertrain()メソッドを呼んだ先はブラックボックスです2
少しは理解したいと手に取ったこちらの動画

動画のコードには、訓練ループに入る前にmodel.train()が出てきます(3:38〜)

model.train()
for epoch in range(num_epochs):
    for batch in train_dataloader:
        # batchをmodelで推論してlossを計算していくが、今回の記事ではスコープアウト
        ...

またmodel.eval()というメソッド呼び出しも見たことがあります。
今回はこれらが何かを調べました。

nn.Module.train()nn.Module.eval()ってなんだ?

モデル(nn.Module)には、modeがある

まず、nn.Module.train()1つの引数modeを持ちます。
https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.train modeのデフォルト値はTrueで、model.train()という呼び出しはmodel.train(mode=True)ということです。

model.train(mode=False)という呼び出しもできるのですが、これがnn.Module.eval()です。
https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.eval

This is equivalent with self.train(False).

再びnn.Module.train()のドキュメントに戻ると

mode (bool) – whether to set training mode (True) or evaluation mode (False). Default: True.

  • mode=Trueが訓練モード(training mode)
  • mode=Falseが評価モード(evaluation mode)

このように、model.train()model.eval()モデル(nn.Module)のモードを指定するメソッドです

訓練と評価でモードが変わる例:Dropout

nn.Module.train()のドキュメントより、モードの切り替えが働く例としてDropoutやBatchNormが挙げられています。

During training, randomly zeroes some of the elements of the input tensor with probability p.

訓練時と評価時の動きの違いは以下

the outputs are scaled by a factor of 1/(1-p) during training. This means that during evaluation the module simply computes an identity function.

  • 訓練時:Dropoutの出力は1/(1-p)倍される
    • 補足:pはTensorの要素をdropoutする確率で、デフォルト値は0.5
  • 評価時:Dropoutの出力はそのまま(identity function 恒等関数)

Dropoutって、訓練時と評価時で動きが違ったんだ〜!

ここまでの解説を読んで、model.train()model.eval()のようにメソッドを呼び分ける理由が腹落ちしました。
モデルの中に訓練と評価とで振る舞いが異なるレイヤがあったときに、その動きを指定しているんですね。

Evaluation Mode

https://pytorch.org/docs/stable/notes/autograd.html#evaluation-mode-nn-module-eval

You are responsible for calling model.eval() and model.train() if your model relies on modules such as torch.nn.Dropout and torch.nn.BatchNorm2d that may behave differently depending on training mode, for example, to avoid updating your BatchNorm running statistics on validation data.

BatchNorm2dについて、訓練時と評価時で振る舞いを変える理由が説明されていますね。

常にmodel.train()model.eval()を呼び出すのがオススメされるようです。

It is recommended that you always use model.train() when training and model.eval() when evaluating your model (validation/testing) even if you aren’t sure your model has training-mode specific behavior, because a module you are using might be updated to behave differently in training and eval modes.

使っているモジュール(nn.Module)が訓練時と評価時とで振る舞いを変えるようにアップデートされても、コード側でmodel.train()model.eval()を呼んでいればバグをもたらすようなことはありませんね。
訓練に入ることをtrain()メソッドで、評価に入ることをeval()メソッドで、モデルさんに教えてあげるんだ!

Future work:model.eval()とは別に、勾配計算をしないモードがある

Evaluation Modeのドキュメントより

Functionally, module.eval() (or equivalently module.train(False)) are completely orthogonal to no-grad mode and inference mode.

model.eval()は、no-gradモード推論モードとは直行する、とあります。

勾配計算をしないモードについては、「Evaluation Mode」と同じドキュメントに記載があります。
https://pytorch.org/docs/stable/notes/autograd.html#locally-disabling-gradient-computation
3つのGrad Modeの理解が積み残しです。

  • default mode
  • no-grad mode
  • inference mode

終わりに

モデルの訓練のコードに登場するmodel.train()model.eval()について調べました。

  • nn.Moduleの2つのモード:訓練モードと評価モード を開発者が指定する
  • モデルのレイヤには、訓練時と評価時とで振る舞いが異なるものがある(Dropout、BatchNorm2d)
    • これらに対して、model.train()model.eval()でモードを指定している

何をやっているかよく分かっていないコードから一転、開発者が指定する必要があると理解できました。
世のTrainerはmodel.train()model.eval()をラップしてくれているんだろうな


  1. LLMファインチューン例
  2. ニューラルネットワークの訓練の理解を深める一冊として、ゼロつく3巻を認識しています