PyTorch Lightning Transfer Learning on a Custom Dataset

PyTorch is great, but as tasks become more complex there are lots of small mistakes that can creep in. This is where PyTorch Lightning shines: it structures your training and data preparation so that it is both extensible for advanced users and easy to use for beginners.

For this blog we will be using the Butterfly dataset, which contains images of 50 different classes of butterflies.

Imports

This is pretty straightforward and does not require much explanation.
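The original code embed isn't reproduced here, but a typical import block for this kind of project might look like the following; the exact set depends on the code in the later sections:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder

import pytorch_lightning as pl
```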

Dataset

PyTorch Lightning has a clean way of handling data using classes: it provides pre-built hooks that automatically attach to the required methods of the class, and they are also customizable.

A few things to note here: the prepare_data function is called only once during training, while setup is called once for each device in the cluster.

Let's say you have 8 cores on a TPU: prepare_data would be called once (generally to download the data), and then setup would be called once on each of the 8 cores, as in the sketch below.
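Since the original embed isn't shown, here is a minimal LightningDataModule sketch using the imports above. The `./butterflies` directory layout, the ImageFolder usage, and the batch size are assumptions about how the dataset is organized:

```python
class ButterflyDataModule(pl.LightningDataModule):
    def __init__(self, data_dir="./butterflies", batch_size=32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),  # ResNet50 expects 224x224 inputs
            transforms.ToTensor(),
        ])

    def prepare_data(self):
        # Called only once: download/extract the dataset here.
        # (Omitted; this sketch assumes the data is already on disk.)
        pass

    def setup(self, stage=None):
        # Called on every device/process: build the Dataset objects.
        self.train_ds = ImageFolder(f"{self.data_dir}/train", transform=self.transform)
        self.val_ds = ImageFolder(f"{self.data_dir}/valid", transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size)
```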

Model

This is where most of the PyTorch Lightning work is done. PyTorch Lightning has preconfigured hooks that let us train the model carefree: for example, it automatically saves a checkpoint after each epoch, can apply early stopping when a loss metric is available, and sets up the device for you, which lets us run the same code on CPU, GPU, and even TPU.

Here we are using ResNet50 for the 50 classes and the Adam optimizer with a fixed learning rate of 1e-4.
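A minimal sketch of that LightningModule follows. Only ResNet50, the 50 classes, and Adam at 1e-4 come from the post; the loss function and logging names are typical choices, not necessarily the original ones:

```python
class ButterflyClassifier(pl.LightningModule):
    def __init__(self, num_classes=50, lr=1e-4):
        super().__init__()
        # Load an ImageNet-pretrained ResNet50 and swap in a 50-way head.
        self.model = models.resnet50(pretrained=True)
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
        self.lr = lr

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        val_loss = F.cross_entropy(self(x), y)
        # Logging a validation loss is what enables loss-based callbacks
        # such as early stopping and checkpoint selection.
        self.log("val_loss", val_loss)

    def configure_optimizers(self):
        # Adam with the fixed 1e-4 learning rate mentioned above.
        return torch.optim.Adam(self.parameters(), lr=self.lr)
```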

Training

Training is as simple as calling trainer.fit in PyTorch Lightning.
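Putting the two sketched classes together, a training run might look like this; max_epochs and the single-GPU setting are assumptions, not values from the post:

```python
dm = ButterflyDataModule(data_dir="./butterflies", batch_size=32)
model = ButterflyClassifier(num_classes=50, lr=1e-4)

# The same script runs on CPU/GPU/TPU; only the Trainer flags change.
trainer = pl.Trainer(max_epochs=10, gpus=1)
trainer.fit(model, datamodule=dm)
```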

Originally published at https://beginers.tech on December 4, 2020.
