はじめに
頑張れば、何かがあるって、信じてる。nikkieです。
最近v1.0が出たPyTorch Lightning、Getting startedのドキュメントに沿って週末に素振りしました。
目次
PyTorch Lightningとは
世はまさに大AI時代といった趣で、TensorFlowをはじめとする深層学習フレームワークが盛んに開発されています。
その中に、Facebook発の深層学習フレームワーク PyTorch があります。
PyTorch Lightning(以下、Lightning)はPyTorchのラッパーライブラリです。
PyTorchで頻繁に書くボイラープレートのコードがなくなるように設計されています。
機械学習スクリプトのエントリポイントが劇的に薄くなるので、「めっちゃイケてる✨」と心ときめきました。
実験に使うスクリプトのエントリポイントが長大になりがちなんですよね。。
if __name__ == "__main__": autoencoder = LitAutoEncoder() dataset = MNIST(Path.cwd(), download=True, transform=transforms.ToTensor()) train_loader = DataLoader(dataset) trainer = pl.Trainer() trainer.fit(autoencoder, train_loader)
10月のv1.0.0リリースに合わせて開発チームからBlogがポストされています。
- v1.0.0でLightningのAPIがfix
- Metricクラス追加
- 劇的に簡単なlogging!(
self.log
)
開発チームはGrid AIというモデルの訓練・サーブプラットフォームも開発していくそうです。
Lightningの哲学
LightningはPyTorchのラッパーなので、書き換え方を押さえなければなりません。
使い始めるのに必要な学習コストが気になりますが、リポジトリのREADMEで紹介されている原則(Lightning Philosophy)がヒントになると思いました。
Principle 4: Deep learning code should be organized into 4 distinct categories.
- Research code (the LightningModule).
- Engineering code (you delete, and is handled by the Trainer).
- Non-essential research code (logging, etc... this goes in Callbacks).
- Data (use PyTorch Dataloaders or organize them into a LightningDataModule).
先のコードでは、LitAutoEncoder
がLightningModule
クラスを継承したResearch codeです。
Engineering codeは全てTrainerに寄せています。
以下、素振りで作ったコードを備忘録代わりに残します。
今回の環境
$ sw_vers ProductName: Mac OS X ProductVersion: 10.14.6 BuildVersion: 18G4032 $ python -V Python 3.8.1 $ pip install pytorch-lightning torchvision # pytorch-lightning 1.0.3 # torchvision 0.7.0 # -- Lightningの依存により以下が入った --- # numpy 1.19.2 # torch 1.6.0 # tensorboard 2.3.0
Lightning in 2 stepsのコード
冒頭で紹介したページのMNISTの例(オートエンコーダー訓練)を進めました。
ページ冒頭の3分動画も参考にしています。
from argparse import ArgumentParser from pathlib import Path import pytorch_lightning as pl import torch import torch.nn.functional as F from torch import nn from torch.utils.data import DataLoader, random_split from torchvision import transforms from torchvision.datasets import MNIST
Research code: LightningModule
class LitAutoEncoder(pl.LightningModule): def __init__(self): super(LitAutoEncoder, self).__init__() self.encoder = nn.Sequential( nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3), ) self.decoder = nn.Sequential( nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28), ) def forward(self, x): embedding = self.encoder(x) return embedding def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) return optimizer def training_step(self, batch, batch_idx): x, y = batch x = x.view(x.size(0), -1) z = self.encoder(x) # forwardを呼び出す self(x) でもよい x_hat = self.decoder(z) loss = F.mse_loss(x_hat, x) self.log("train_loss", loss) return loss def validation_step(self, val_batch, batch_idx): x, y = val_batch x = x.view(x.size(0), -1) z = self.encoder(x) x_hat = self.decoder(z) loss = F.mse_loss(x_hat, x) self.log("val_loss", loss) return loss
ポイントと思ったところ1
LightningModule
は1つのモデルではなく、複数のモデルからなるシステムを扱える2(この例でもencoderとdecoderを扱っている)forward
メソッドとtraining_step
メソッドを分離3(forward
はデプロイしたモデルの推論でも使える)training_step
メソッド(訓練)でlossを返している限りはLightningによって自動で最適化される(backward、optimizerの更新)
Engineering code: Trainer
- 訓練時の設定(GPUを使うかなど)を元にTraninerを初期化し、
- Trainerにモデル(
LightningModule
を継承したクラス)とデータを渡して訓練します
1のTranierの初期化は、コマンドラインからも指定できます4。
Trainer初期化の引数をいじらなくていいので便利そうです。
def parse_args(): parser = ArgumentParser() parser = pl.Trainer.add_argparse_args(parser) args = parser.parse_args() return args if __name__ == "__main__": args = parse_args() trainer = pl.Trainer.from_argparse_args(args)
コマンドラインからTrainerに渡すと便利そうだと思ったオプション
--gpus 個数
to(device)
というPyTorchで頻出するコードから解放されたのはすごい!
--fast_dev_run
:コード全体のユニットテストオプション5- single batchのtrain, val, testデータで実行
- これにより、つまらないミスで訓練中に落ちる悲しい事件とさようならできる!
--limit_train_batches バッチ数
,--limit_val_batches バッチ数
:少量のデータで実行--max_epochs
,--max_steps
も設定できる--deterministic
:再現性の担保6seed_everything
関数と一緒に使う(numpy, torch, random, PYTHONHASHSEED全部のシードを固定)
Data: PyTorch Dataloader & LightningDataModule
train, val, testの各データの扱いをエントリポイントから、データモジュールクラスに移せるのが好感触です。
class MNISTDataModule(pl.LightningDataModule): def __init__(self, batch_size=32): super(MNISTDataModule, self).__init__() self.batch_size = batch_size def prepare_data(self): MNIST(Path.cwd(), train=True, download=True) MNIST(Path.cwd(), train=False, download=True) def setup(self, stage): if stage == "fit": mnist_train = MNIST( Path.cwd(), train=True, transform=transforms.ToTensor() ) self.mnist_train, self.mnist_val = random_split( mnist_train, [55000, 5000] ) if stage == "test": self.mnist_test = MNIST( Path.cwd(), train=False, transform=transforms.ToTensor() ) def train_dataloader(self): mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size) return mnist_train def val_dataloader(self): mnist_val = DataLoader(self.mnist_val, batch_size=self.batch_size) return mnist_val def test_dataloader(self): mnist_test = DataLoader(self.mnist_test, batch_size=self.batch_size) return mnist_test
if __name__ == "__main__": args = parse_args() autoencoder = LitAutoEncoder() dm = MNISTDataModule() trainer = pl.Trainer.from_argparse_args(args) trainer.fit(autoencoder, dm) # trainer.test(後ほど試す)
Non-essential research code: Callbacks
今回は手を動かしていませんが、early stoppingはCallbackで実現するそうです。
logging7はself.log
に指標名と一緒に渡すだけ。
プログレスバーやtensorboardで確認できます(log
の引数で調整もできます)。
tensorboard --logdir ./lightning_logs
※lightning_logs
ディレクトリにチェックポイントが保存されています
感想
エントリポイントが薄くなった高揚感でここまで書き上げました。
Lightningはすごく有用そうです!
ただ自動最適化はブラックボックス化でもあるので、PyTorchのボイラープレートコードも理解し、Lightningで賢く楽をしたいですね。
今後は実際に学習を回して素振りを繰り返し、ドキュメントで知ったことを知識に変えていこうと思います。
チュートリアルのMNIST 60000件はCPUでは訓練がサクサクいかないので、データを間引くか、他のデータセットを探すかですね。
訓練したモデルはtorchscriptなる形式で掃き出せるそうです。
この週末DeNA AIチャンネルで知って、よさそうに感じたstreamlitでdecoderをアプリ化しても面白そうだなと思いました。
-
superの呼び出し方は
super().__init__()
でも同じと確認しました。ref: https://docs.python.org/ja/3/library/functions.html#super↩ -
「PyTorch Lightning was designed to encapsulate a collection of models interacting together」(v1.0.0リリースのBlog)また、https://pytorch-lightning.readthedocs.io/en/latest/new-project.html#step-1-define-lightningmodule の「SYSTEM VS MODEL」↩
-
https://pytorch-lightning.readthedocs.io/en/latest/new-project.html#step-1-define-lightningmodule の「FORWARD vs TRAINING_STEP」↩
-
https://pytorch-lightning.readthedocs.io/en/latest/trainer.html#trainer-in-python-scripts
argparse
を使っているので、-h
で全オプションが見られます↩ -
https://pytorch-lightning.readthedocs.io/en/stable/new-project.html#debugging↩
-
https://pytorch-lightning.readthedocs.io/en/latest/trainer.html#reproducibility↩
-
https://pytorch-lightning.readthedocs.io/en/stable/new-project.html#logging↩