PyTorch Lightning Transfer Learning on a Custom Dataset

PyTorch is great, but as a task becomes complex, there are lots of small mistakes that can happen. This is where PyTorch Lightning shines: it structures your training and data preparation so that it is both extensible for advanced users and easy to use for beginners.

For this blog post we will be using the Butterfly Dataset, which contains images of 50 different classes of butterflies.

Imports

This is pretty straightforward and does not require much explanation.

Dataset

PyTorch Lightning has a clean, class-based way of handling data. It provides pre-built hooks that are called automatically at the right point in the workflow, and each of them can be customized.

A few things to note here: prepare_data is called from a single process (by default, once per node), while setup is called once for each device in the cluster.

Let's say you have a TPU with 8 cores: prepare_data would be called once (generally to download the data), and setup would then be called once on each of the 8 cores.

Model

This is where most of the PyTorch Lightning work is done. PyTorch Lightning has preconfigured hooks that let us train a model carefree: for example, it automatically saves a checkpoint after each epoch, supports early stopping through a callback when a monitored metric such as validation loss is logged, and sets up the device for you. This allows us to run the same code on CPU, GPU, and TPU.

Here we are using ResNet50 with a 50-class output layer and the Adam optimizer with a fixed learning rate of 1e-4.

Training

Training is as simple as calling trainer.fit in PyTorch Lightning.

Originally published at https://beginers.tech on December 4, 2020.

Prashant Singh
