nikkie-ftnextの日記

イベントレポートや読書メモを発信

週末ログ | PyTorch Lightningの"Lightning in 2 steps"を触りました⚡️

はじめに

頑張れば、何かがあるって、信じてる。nikkieです。
最近v1.0が出たPyTorch Lightning、Getting startedのドキュメントに沿って週末に素振りしました。

目次

PyTorch Lightningとは

世はまさに大AI時代といった趣で、TensorFlowをはじめとする深層学習フレームワークが盛んに開発されています。
その中に、Facebook発の深層学習フレームワーク PyTorch があります。
PyTorch Lightning(以下、Lightning)はPyTorchのラッパーライブラリです。

PyTorchで頻繁に書くボイラープレートのコードがなくなるように設計されています。
機械学習スクリプトエントリポイントが劇的に薄くなるので、「めっちゃイケてる✨」と心ときめきました。
実験に使うスクリプトのエントリポイントが長大になりがちなんですよね。。

if __name__ == "__main__":
    autoencoder = LitAutoEncoder()
    
    dataset = MNIST(Path.cwd(), download=True, transform=transforms.ToTensor())
    train_loader = DataLoader(dataset)

    trainer = pl.Trainer()
    trainer.fit(autoencoder, train_loader)

10月のv1.0.0リリースに合わせて開発チームからBlogがポストされています。

  • v1.0.0でLightningのAPIがfix
  • Metricクラス追加
  • 劇的に簡単なlogging!(self.log

開発チームはGrid AIというモデルの訓練・サーブプラットフォームも開発していくそうです。

Lightningの哲学

LightningはPyTorchのラッパーなので、書き換え方を押さえなければなりません。
使い始めるのに必要な学習コストが気になりますが、リポジトリのREADMEで紹介されている原則(Lightning Philosophy)がヒントになると思いました。

Principle 4: Deep learning code should be organized into 4 distinct categories.

  • Research code (the LightningModule).
  • Engineering code (you delete, and is handled by the Trainer).
  • Non-essential research code (logging, etc... this goes in Callbacks).
  • Data (use PyTorch Dataloaders or organize them into a LightningDataModule).

先のコードでは、LitAutoEncoderLightningModuleクラスを継承したResearch codeです。
Engineering codeは全てTrainerに寄せています。

以下、素振りで作ったコードを備忘録代わりに残します。

今回の環境

$ sw_vers
ProductName:    Mac OS X
ProductVersion: 10.14.6
BuildVersion:   18G4032
$ python -V
Python 3.8.1
$ pip install pytorch-lightning torchvision
# pytorch-lightning      1.0.3
# torchvision            0.7.0

# -- Lightningの依存により以下が入った ---
# numpy                  1.19.2
# torch                  1.6.0
# tensorboard            2.3.0

Lightning in 2 stepsのコード

冒頭で紹介したページのMNISTの例(オートエンコーダー訓練)を進めました。
ページ冒頭の3分動画も参考にしています。

from argparse import ArgumentParser
from pathlib import Path

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

Research code: LightningModule

class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super(LitAutoEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 64),
            nn.ReLU(),
            nn.Linear(64, 3),
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 28 * 28),
        )

    def forward(self, x):
        embedding = self.encoder(x)
        return embedding

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)  # forwardを呼び出す self(x) でもよい
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("val_loss", loss)
        return loss

ポイントと思ったところ1

  • LightningModuleは1つのモデルではなく、複数のモデルからなるシステムを扱える2(この例でもencoderとdecoderを扱っている)
  • forwardメソッドとtraining_stepメソッドを分離3forwardはデプロイしたモデルの推論でも使える)
  • training_stepメソッド(訓練)でlossを返している限りはLightningによって自動で最適化される(backward、optimizerの更新)

Engineering code: Trainer

  1. 訓練時の設定(GPUを使うかなど)を元にTraninerを初期化し、
  2. Trainerにモデル(LightningModuleを継承したクラス)とデータを渡して訓練します

1のTranierの初期化は、コマンドラインからも指定できます4
Trainer初期化の引数をいじらなくていいので便利そうです。

def parse_args():
    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    
    trainer = pl.Trainer.from_argparse_args(args)

コマンドラインからTrainerに渡すと便利そうだと思ったオプション

  • --gpus 個数
    • to(device)というPyTorchで頻出するコードから解放されたのはすごい!
  • --fast_dev_run:コード全体のユニットテストオプション5
    • single batchのtrain, val, testデータで実行
    • これにより、つまらないミスで訓練中に落ちる悲しい事件とさようならできる!
  • --limit_train_batches バッチ数, --limit_val_batches バッチ数:少量のデータで実行
  • --max_epochs, --max_stepsも設定できる
  • --deterministic:再現性の担保6
    • seed_everything関数と一緒に使う(numpy, torch, random, PYTHONHASHSEED全部のシードを固定)

Data: PyTorch Dataloader & LightningDataModule

train, val, testの各データの扱いをエントリポイントから、データモジュールクラスに移せるのが好感触です。

class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=32):
        super(MNISTDataModule, self).__init__()
        self.batch_size = batch_size

    def prepare_data(self):
        MNIST(Path.cwd(), train=True, download=True)
        MNIST(Path.cwd(), train=False, download=True)

    def setup(self, stage):
        if stage == "fit":
            mnist_train = MNIST(
                Path.cwd(), train=True, transform=transforms.ToTensor()
            )
            self.mnist_train, self.mnist_val = random_split(
                mnist_train, [55000, 5000]
            )
        if stage == "test":
            self.mnist_test = MNIST(
                Path.cwd(), train=False, transform=transforms.ToTensor()
            )

    def train_dataloader(self):
        mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)
        return mnist_train

    def val_dataloader(self):
        mnist_val = DataLoader(self.mnist_val, batch_size=self.batch_size)
        return mnist_val

    def test_dataloader(self):
        mnist_test = DataLoader(self.mnist_test, batch_size=self.batch_size)
        return mnist_test
if __name__ == "__main__":
    args = parse_args()

    autoencoder = LitAutoEncoder()

    dm = MNISTDataModule()

    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(autoencoder, dm)
    # trainer.test(後ほど試す) 

Non-essential research code: Callbacks

今回は手を動かしていませんが、early stoppingはCallbackで実現するそうです。

logging7self.logに指標名と一緒に渡すだけ。
プログレスバーtensorboardで確認できます(logの引数で調整もできます)。

tensorboard --logdir ./lightning_logs

lightning_logsディレクトリにチェックポイントが保存されています

感想

エントリポイントが薄くなった高揚感でここまで書き上げました。
Lightningはすごく有用そうです!
ただ自動最適化はブラックボックス化でもあるので、PyTorchのボイラープレートコードも理解し、Lightningで賢く楽をしたいですね。

今後は実際に学習を回して素振りを繰り返し、ドキュメントで知ったことを知識に変えていこうと思います。
チュートリアルのMNIST 60000件はCPUでは訓練がサクサクいかないので、データを間引くか、他のデータセットを探すかですね。

訓練したモデルはtorchscriptなる形式で掃き出せるそうです。
この週末DeNA AIチャンネルで知って、よさそうに感じたstreamlitでdecoderをアプリ化しても面白そうだなと思いました。