【初心者向け】PyTorch Lightningとは 

PyTorch Lightningとは 

PyTorch Lightning(略:PL)とは、PyTorchで書く必要な処理がメソッド化されたことで、コーディングを簡単にできるフレーワークです。PyTorch Lightningを用いるこで、開発スピードを高めることが出来ます。

引用:https://github.com/PyTorchLightning/pytorch-lightning

PyTorch Lightningをインストールする

PyTorch Lightningは「pip」コマンドでインストールが可能です。

pip install pytorch-lightning

PyTorch Lightningを読み込む

PyTorch Lightningを利用するには以下のようにimportすれば良いです。なおよく名前にはplとつけることが多いです。

import pytorch_lightning as pl

PyTorch LightningでMNISTを実施する

今回はGoogle Colab環境を用いて、MNIST分類モデルを作るコードをPLで書いてみます。

# 必要なライブラリの読み込み
import os # os.getcwd()を利用するためにインポート
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl

# LightningModule (nn.Module のサブクラス) の定義
class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))
        
    # 予測/推論アクションを定義
    def forward(self, x):
        embedding = self.encoder(x)
        return embedding

    # train ループを定義
    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

# 学習
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train, val = random_split(dataset, [55000, 5000])

autoencoder = LitAutoEncoder()
trainer = pl.Trainer()
trainer.fit(autoencoder, DataLoader(train), DataLoader(val))

上記のように、PyTorch Lightningを用いると、40行ほどでコードが書けてしまいます。

おわりに

この記事では、PyTorch Lightningとは何か、インストール方法、MNISTの分類器を作成する一連のコードを紹介致しました。PyTorch Lightningなどを独学で勉強するときに、わからないことが出てきた場合、そこからなかなか先にスムーズに学習できないことがあります。

そのような方が機械学習を学ぶ上でオススメなのは、機械学習エンジニアからいつでも質問できる環境で学ぶことです。

AI Academy Bootcampなら、6ヶ月35,000円にてチャットで質問し放題の環境で、機械学習やデータ分析が学べるサービスを提供しております。
数十名在籍しているデータサイエンティストや機械学習エンジニアに質問し放題の環境でデータ分析、統計、機械学習、SQL等が学べます。AI人材に必要なスキルを効率よく体系的に身に付けたい方は是非ご検討ください。

コメントを残す