PyTorch Lightning Transfer Learning on a Custom Dataset
PyTorch is great, but as tasks grow more complex there are lots of small mistakes that can creep in. This is where PyTorch Lightning shines: it structures your training and data preparation so that it is both extensible for advanced users and easy to use for beginners.
For this blog we will be using the Butterfly Dataset, which contains images of 50 different classes of butterflies.
This is pretty straightforward and does not require much explanation.
PyTorch Lightning has a clean way of handling data using classes: it provides pre-built hooks that automatically attach to the required methods of the class, and they are also customizable.
A few things to note here: prepare_data is called only once during training, while setup is called once for each device in the cluster.
Say you have 8 cores on a TPU: prepare_data would be called once (generally to download the data), and then setup would be called once on each of the 8 cores.
This is where most of the PyTorch Lightning work is done. PyTorch Lightning has preconfigured hooks that let us train the model carefree: for example, it automatically saves a checkpoint after each epoch, implements early stopping if a loss metric is available, and sets up the device for you, which lets the same code run on CPU, GPU, and even TPU.
Here we use ResNet50 for the 50 classes and the Adam optimizer with a fixed learning rate of 1e-4.
Training is as simple as calling trainer.fit in PyTorch Lightning
Originally published at https://beginers.tech on December 4, 2020.