PyTorch Lightningとは
PyTorch Lightning(略:PL)とは、PyTorchで書く必要な処理がメソッド化されたことで、コーディングを簡単にできるフレーワークです。PyTorch Lightningを用いるこで、開発スピードを高めることが出来ます。
引用:https://github.com/PyTorchLightning/pytorch-lightning
PyTorch Lightningをインストールする
PyTorch Lightningは「pip」コマンドでインストールが可能です。
pip install pytorch-lightning
PyTorch Lightningを読み込む
PyTorch Lightningを利用するには以下のようにimportすれば良いです。なおよく名前にはplとつけることが多いです。
import pytorch_lightning as pl
PyTorch LightningでMNISTを実施する
今回はGoogle Colab環境を用いて、MNIST分類モデルを作るコードをPLで書いてみます。
# 必要なライブラリの読み込み
import os # os.getcwd()を利用するためにインポート
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl
# LightningModule (nn.Module のサブクラス) の定義
class LitAutoEncoder(pl.LightningModule):
def __init__(self):
super().__init__()
self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))
# 予測/推論アクションを定義
def forward(self, x):
embedding = self.encoder(x)
return embedding
# train ループを定義
def training_step(self, batch, batch_idx):
x, y = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
loss = F.mse_loss(x_hat, x)
self.log("train_loss", loss)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer
# 学習
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train, val = random_split(dataset, [55000, 5000])
autoencoder = LitAutoEncoder()
trainer = pl.Trainer()
trainer.fit(autoencoder, DataLoader(train), DataLoader(val))
上記のように、PyTorch Lightningを用いると、40行ほどでコードが書けてしまいます。
おわりに
この記事では、PyTorch Lightningとは何か、インストール方法、MNISTの分類器を作成する一連のコードを紹介致しました。PyTorch Lightningなどを独学で勉強するときに、わからないことが出てきた場合、そこからなかなか先にスムーズに学習できないことがあります。
そのような方が機械学習を学ぶ上でオススメなのは、機械学習エンジニアからいつでも質問できる環境で学ぶことです。
AI Academy Bootcampなら、6ヶ月35,000円にてチャットで質問し放題の環境で、機械学習やデータ分析が学べるサービスを提供しております。
数十名在籍しているデータサイエンティストや機械学習エンジニアに質問し放題の環境でデータ分析、統計、機械学習、SQL等が学べます。AI人材に必要なスキルを効率よく体系的に身に付けたい方は是非ご検討ください。