Shinonome Tech Blog

株式会社Shinonomeの技術ブログ
11 min read

Pytorchでピザ判定機を作る

こんにちは、Playgroundのデータコースに所属している安藤太一です。昨日ピザを食べたくなったのですが、目の前の食べ物がピザかピザじゃないか分からなくなってしまいました。なので今回はPytorchを用いて画像に写っている物がピザかピザじゃないかを判定する深層学習モデルを作っていこうと思います。

Pytorchでピザ判定機を作る

こんにちは、Playgroundのデータコースに所属している安藤太一です。昨日ピザを食べたくなったのですが、目の前の食べ物がピザかピザじゃないか分からなくなってしまいました。なので今回はPytorchを用いて画像に写っている物がピザかピザじゃないかを判定する深層学習モデルを作っていこうと思います。

food_pizza-1

注意: 初めての画像分類タスクなので間違いがあるかもしれません。参考にされる際は十分注意してください。

目次

  • ライブラリ類の準備
  • データセットの準備
  • データローダーの作成
  • 学習
  • 結果確認
  • 参考文献

ライブラリ類の準備

今回使うライブラリやパッケージをインストールしていきます。
主に使うライブラリは、torch、opendatasets、torchvisionです。

from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
!pip install -qqq torchviz
!pip install -qqq opendatasets
  Building wheel for torchviz (setup.py) ... [?25l[?25hdone
import torch # pytorch のテンソル構造など全般
from torch import optim
from torch import nn # ネットワークや各種レイヤー
from torch.nn import functional #より詳しいレイヤー
from torch.utils.data import DataLoader
from torchvision import datasets # 画像データセットのモジュール
from torchvision import transforms # 画像をTorchのテンソルに変換する
from torchviz import make_dot
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import opendatasets as op
from torchvision import models
from torch.optim import Adam, lr_scheduler
from sklearn.metrics import accuracy_score
import time
from datetime import timedelta
import shutil

データセットの準備

今回使うデータセットをダウンロードします。今回はKaggleで公開されている、「Pizza or Not Pizza?」というデータセットを使います。opendatasetsというライブラリを用いてダウンロードします。
op.download()でダウンロードができますが、kaggleからのダウンロードの場合にはkaggle apiで使われるkaggle.jsonがカレントディレクトリに無いとエラーが出るので注意が必要です。

shutil.copyfile("/content/drive/MyDrive/kaggle/kaggle.json", "/content/kaggle.json") #kaggle.jsonをコピーしてカレントディレクトリに移動
op.download("https://www.kaggle.com/datasets/carlosrunner/pizza-not-pizza")
Downloading pizza-not-pizza.zip to ./pizza-not-pizza


100%|██████████| 101M/101M [00:01<00:00, 74.3MB/s] 

ディレクトリ構造は以下のようになります。not_pizzaにピザでない食べ物の画像、pizzaにピザの画像が入っています。

└── pizza_not_pizza
    ├── food101_subset.py
    ├── not_pizza [983 entries exceeds filelimit, not opening dir]
    └── pizza [983 entries exceeds filelimit, not opening dir]

データローダーの作成

Pytorchでミニバッチ学習を行うためには、DatasetとDataloaderの作成が必要になります。
データがダウンロードできたら、まずDatasetを作成していきます。最初にtorchのモジュールであるdatasetsのImageFolderメソッドを用いてデータを読み込みます。
ディレクトリがtrainデータとtestデータで別れている構造の場合、読み込んだままDatasetにすることができますが、今回は別れていないため自分で分ける必要があります。そのためにDatasetのクラスを改造してサブクラスを定義します。(インデックスでデータを分けれるようにしています。)

data_path = "/content/pizza-not-pizza/pizza_not_pizza"
# os.rmdir(data_path + "/.ipynb_checkpoints")

data = datasets.ImageFolder("/content/pizza-not-pizza/pizza_not_pizza")
#Datasetのサブクラス(trainとvalidに分けるために必要)
class MySubset(torch.utils.data.Dataset):
    def __init__(self, dataset, indices, transform=None):
        self.dataset = dataset
        self.indices = indices
        self.transform = transform

    def __getitem__(self, idx):
        img, label = self.dataset[self.indices[idx]]
        if self.transform:
            img = self.transform(img)

        return img, label

    def __len__(self):
        return len(self.indices)

サブクラスを定義できたら次は画像を読み込んだ時に行う変換をtransforms.Composeの中に定義します。基本的にはテンソルに変換するだけですが、今回訓練データでは、正規化を始め、Data augumentationというモデルに頑健性を持たせるための変換(例えば画像を反転させたり、ランダムに一部を黒く塗りつぶしたり)を定義しています。また今回使うVGG_19bnという事前学習済みモデルが画像サイズ244*244で学習されたモデルであるので、そのサイズに合わせるようにリサイズする変換も盛り込んであります。

#訓練データ用
#正規化、反転、RandomErasing←ランダムで一部を消すやつ
train_transform = transforms.Compose([
                                      transforms.RandomResizedCrop(224),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize(0.5, 0.5),
                                      transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 0.3), value=0, inplace=False)                        
])

#検証データ用
#正規化だけ
val_transform = transforms.Compose([
                                     transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(0.5, 0.5)
])

transformsが定義できたら実際にDataloaderを作っていきます。data→Dataset→Dataloaderの順で定義していきます。この時に、データを訓練データ、検証データ、テストデータに分けます。

batch_size = 32
train_size = 0.8
valid_size = 0.1

train_idx = int(train_size * len(data))
valid_idx = int((train_size + valid_size) * len(data)) 
indices = np.arange(len(data))

train_dataset = MySubset(data, indices[:train_idx], train_transform)
val_dataset = MySubset(data, indices[train_idx:valid_idx], val_transform)
test_dataset = MySubset(data, indices[valid_idx:], val_transform)

train_loader = DataLoader(train_dataset, batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size, drop_last=True)

今一度データのクラスを確認しておきます。

data.classes
['not_pizza', 'pizza']

ちゃんとデータが分けられているかどうかを確認するために、各データセットの大きさを確認します。

print(f"train length: {len(train_dataset)}")
print(f"valid length: {len(val_dataset)}")
print(f"test length: {len(test_dataset)}")
train length: 1572
valid length: 197
test length: 197

試しに画像を確認しておきます。

plt.imshow(data[1000][0])
plt.title(data.classes[data[1000][1]])
plt.show()

blog2_vgg19_22_0-1

plt.imshow(data[1][0])
plt.title(data.classes[data[1][1]])
plt.show()

blog2_vgg19_23_0

学習

次に実際に学習をしていきます。事前学習済みモデルを読み込んだ後に、最終層を今回分類したい数に合わせます。今回は2クラス分類なので最終層のノードの数を2に設定します。(1つでもいいかも?)
また学習させる準備として、モデルをGPUに乗せる、損失関数の定義、最適化手法の定義、学習率を変化させるスケジューラーの設定を行います。

net = models.vgg19_bn(pretrained=True)

torch.manual_seed(42)
torch.cuda.manual_seed(42)

#最終ノードの出力を2にする
in_features = net.classifier[6].in_features
net.classifier[6] = nn.Linear(in_features, 2)

net.avgpool = nn.Identity()

device = "cuda:0" if torch.cuda.is_available() else "cpu"

net = net.to(device)

lr = 0.001

loss_CE = nn.CrossEntropyLoss()

optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)

scheduler = lr_scheduler.LinearLR(optimizer)
/usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:209: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead.
  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "
/usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=VGG19_BN_Weights.IMAGENET1K_V1`. You can also use `weights=VGG19_BN_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth" to /root/.cache/torch/hub/checkpoints/vgg19_bn-c79401a0.pth



  0%|          | 0.00/548M [00:00<?, ?B/s]
n_epochs = 20

ようやく学習をします。学習の流れとしては、予測→損失を出す→勾配を求め逆伝播する→パラメータ更新→検証データで精度を確認のサイクルをぐるぐる回す感じです。

loss_list = []
acc_train_list = []
acc_valid_list = []
acc_list = []
for epoch in range(n_epochs):
    start_time = time.time()
    net.train()
    loss_train = 0

    for data in train_loader:

        inputs, labels = data

        optimizer.zero_grad()

        output = net.forward(inputs.to(device))

        loss = loss_CE(output, labels.to(device))

        loss.backward()

        optimizer.step()

        loss_train += loss.tolist()

        predicted = torch.max(output, 1)[1].cpu()
        acc_train_list.append(((predicted == labels).sum() / len(labels)).item())


    scheduler.step()

    net.eval()
    loss_valid = 0

    for data in val_loader:
        inputs, labels = data

        output = net.forward(inputs.to(device))

        loss = loss_CE(output, labels.to(device))

        loss_valid += loss.tolist()

        predicted = torch.max(output, 1)[1].cpu()
        acc_valid_list.append(((predicted == labels).sum() / len(labels)).item())

    delta = timedelta(seconds=time.time()-start_time)
    loss_list.append([loss_train, loss_valid])
    train_acc = (sum(acc_train_list) / len(acc_train_list))
    valid_acc = (sum(acc_valid_list) / len(acc_valid_list))
    acc_list.append([train_acc, valid_acc])
    print(epoch, f"\ttrain_loss: {loss_train:.5f}", f"\tvalid_loss: {loss_valid:.5f}", f"\ttime_delta: {delta}", f"\ttrain_acc: {train_acc:.5f}",  f"\tvalid_acc: {valid_acc:.5f}")
0 	train_loss: 25.60905 	valid_loss: 2.32306 	time_delta: 0:00:32.902674 	train_acc: 0.73533 	valid_acc: 0.84896
1 	train_loss: 15.03263 	valid_loss: 0.65057 	time_delta: 0:00:25.437050 	train_acc: 0.80293 	valid_acc: 0.90365
2 	train_loss: 11.17252 	valid_loss: 0.64828 	time_delta: 0:00:25.422561 	train_acc: 0.83695 	valid_acc: 0.92535
3 	train_loss: 9.86615 	valid_loss: 0.57948 	time_delta: 0:00:25.434066 	train_acc: 0.85778 	valid_acc: 0.93620
4 	train_loss: 10.04306 	valid_loss: 0.50447 	time_delta: 0:00:25.385084 	train_acc: 0.86862 	valid_acc: 0.94375
5 	train_loss: 8.65810 	valid_loss: 0.71927 	time_delta: 0:00:25.477165 	train_acc: 0.87883 	valid_acc: 0.94618
6 	train_loss: 7.56490 	valid_loss: 0.69913 	time_delta: 0:00:25.491099 	train_acc: 0.88757 	valid_acc: 0.94866
7 	train_loss: 6.41282 	valid_loss: 0.44147 	time_delta: 0:00:25.409948 	train_acc: 0.89525 	valid_acc: 0.95182
8 	train_loss: 6.63811 	valid_loss: 0.85102 	time_delta: 0:00:25.473457 	train_acc: 0.90086 	valid_acc: 0.95139
9 	train_loss: 7.01148 	valid_loss: 0.49679 	time_delta: 0:00:25.455468 	train_acc: 0.90523 	valid_acc: 0.95365
10 	train_loss: 5.86267 	valid_loss: 0.75677 	time_delta: 0:00:25.388607 	train_acc: 0.90961 	valid_acc: 0.95360
11 	train_loss: 5.87784 	valid_loss: 0.47422 	time_delta: 0:00:25.421913 	train_acc: 0.91327 	valid_acc: 0.95530
12 	train_loss: 4.91787 	valid_loss: 0.59399 	time_delta: 0:00:25.385410 	train_acc: 0.91690 	valid_acc: 0.95633
13 	train_loss: 4.95380 	valid_loss: 0.67660 	time_delta: 0:00:25.490671 	train_acc: 0.92010 	valid_acc: 0.95796
14 	train_loss: 4.49343 	valid_loss: 0.26393 	time_delta: 0:00:25.411961 	train_acc: 0.92317 	valid_acc: 0.95972
15 	train_loss: 4.03241 	valid_loss: 0.75706 	time_delta: 0:00:25.433537 	train_acc: 0.92578 	valid_acc: 0.96029
16 	train_loss: 5.39799 	valid_loss: 0.30651 	time_delta: 0:00:25.366988 	train_acc: 0.92793 	valid_acc: 0.96140
17 	train_loss: 4.21324 	valid_loss: 0.60435 	time_delta: 0:00:25.400234 	train_acc: 0.93006 	valid_acc: 0.96209
18 	train_loss: 3.90613 	valid_loss: 0.69240 	time_delta: 0:00:25.454599 	train_acc: 0.93213 	valid_acc: 0.96272
19 	train_loss: 3.04218 	valid_loss: 0.49556 	time_delta: 0:00:25.409091 	train_acc: 0.93437 	valid_acc: 0.96380

モデルが学習できたら、保存します。

torch.save(net.state_dict(), "/content/drive/MyDrive/Mymodel/PIzzaNet.pkl")
del net, inputs, labels, loss, output
torch.cuda.empty_cache()

学習結果を確認していきます。下の図が損失のグラフ、更に下が正解率のグラフとなっています。正解率を見る限り、エポック数をもう少し増やしても良かったかもしれません。またどちらの図も訓練データの精度のほうが検証用データの精度より低いのは、Data augumentationをしているからだと考えられます。

plt.plot(np.array(loss_list)[:, 0], label="train")
plt.plot(np.array(loss_list)[:, 1], label="valid")
plt.legend()
plt.grid()
plt.xlabel("Epoch")
plt.ylabel("Cross Entropy")
plt.title("loss")
Text(0.5, 1.0, 'loss')

blog2_vgg19_32_1

plt.plot(np.array(acc_list)[:, 0], label="train")
plt.plot(np.array(acc_list)[:, 1], label="valid")
plt.legend()
plt.grid()
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.title("Accuracy")
Text(0.5, 1.0, 'Accuracy')

blog2_vgg19_33_1

結果

保存したモデルを読み込んで、実際にテストデータで精度を確認していきましょう。

net = models.vgg19_bn(pretrained=False)
#最終ノードの出力を2にする
device = "cuda:0" if torch.cuda.is_available() else "cpu"
in_features = net.classifier[6].in_features
net.classifier[6] = nn.Linear(in_features, 2)
net.load_state_dict(torch.load("/content/drive/MyDrive/Mymodel/PIzzaNet.pkl"))
net.to(device)
/usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:209: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead.
  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "
/usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=None`.
  warnings.warn(msg)





VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (16): ReLU(inplace=True)
    (17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (19): ReLU(inplace=True)
    (20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (22): ReLU(inplace=True)
    (23): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (24): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (25): ReLU(inplace=True)
    (26): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (27): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (29): ReLU(inplace=True)
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (32): ReLU(inplace=True)
    (33): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (34): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (35): ReLU(inplace=True)
    (36): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (37): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (38): ReLU(inplace=True)
    (39): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (40): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (42): ReLU(inplace=True)
    (43): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (44): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (45): ReLU(inplace=True)
    (46): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (47): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (48): ReLU(inplace=True)
    (49): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (50): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (51): ReLU(inplace=True)
    (52): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=2, bias=True)
  )
)

テストデータに対して予測させ、精度を確認します。

test_output = []
test_label = []
net.eval()
for data in test_loader:
    inputs, labels = data
    output = net.forward(inputs.to(device))
    test_output.extend(output.tolist())
    test_label.extend(labels.tolist())
predictions = np.argmax(test_output, axis=1)
accuracy_score(test_label, predictions)
0.9739583333333334

正解率は約97%でした。これならピザかピザじゃないか判定できそうです。助かりました。

#参考文献
"""
前処理系: https://pystyle.info/pytorch-list-of-transforms/
データセット: https://www.kaggle.com/datasets/carlosrunner/pizza-not-pizza
データセットを分けるためのコード: https://pystyle.info/pytorch-split-dataset/
"""