TensorFlow TPU and TFRecords

Before I start, let me say this: “TPUs are fast, and by fast I mean crazy fast. The biggest bottleneck for a TPU is its data loading process.”

This article focuses on how to combine the TFRecord format with TPU processing to optimize data loading and minimize training time.

There are five major steps to keep in mind when you are using TPUs in TensorFlow.

TPU initialization

Let's start with TPU initialization. It is a simple but very important step. TPUs are multi-node workers in the cloud, and you don't have direct access to them the way you do with GPUs and CPUs, so you need to initialize a network connection to the TPU processing nodes.

Below is example code to do so:

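import tensorflow as tf

tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # locate the TPU cluster
tf.config.experimental_connect_to_cluster(tpu)  # connect to the remote TPU workers
tf.tpu.experimental.initialize_tpu_system(tpu)  # initialize the TPU system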

Distribution Strategy

Since there are multiple devices to work with, you need a distribution strategy to take advantage of the multiple processing nodes. The concept is similar to the distribution strategies you might use for multi-GPU training.

Read more about distributed training here.

# instantiate a distribution strategy
tpu_strategy = tf.distribute.experimental.TPUStrategy(tpu)
AUTO = tf.data.experimental.AUTOTUNE  # let tf.data tune its own parallelism
REPLICAS = tpu_strategy.num_replicas_in_sync  # number of TPU cores in sync
print(f'REPLICAS: {REPLICAS}')
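A common use of REPLICAS is scaling the batch size, since each global batch is split evenly across the replicas. The sketch below is a minimal illustration, assuming a hypothetical per-replica batch of 16; it uses tf.distribute.get_strategy() (TensorFlow's default, single-replica strategy) as a stand-in for tpu_strategy so it runs without a TPU.

```python
import tensorflow as tf

# Stand-in for tpu_strategy so the sketch runs on CPU/GPU (assumption);
# on a TPU, use the TPUStrategy created above instead.
strategy = tf.distribute.get_strategy()
REPLICAS = strategy.num_replicas_in_sync

# Hypothetical per-replica batch size. The global batch is divided evenly
# among the replicas, so it grows with the number of replicas.
BATCH_SIZE_PER_REPLICA = 16
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * REPLICAS
print(f'REPLICAS: {REPLICAS}, GLOBAL_BATCH_SIZE: {GLOBAL_BATCH_SIZE}')
```

On an 8-core TPU (REPLICAS = 8), this gives a global batch of 128, with each core processing 16 examples per step.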

Data Type

Currently, only the tf.float32, tf.int32, tf.bfloat16, and tf.bool data types are supported on the TPU. Other common data types, such as tf.uint8, tf.string, and tf.int64, must be converted to one of the supported data types during data pre-processing. I mostly use float32 and have never personally tried bfloat16, but it should work just fine.
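A minimal sketch of such a conversion, written as a function you could pass to Dataset.map. The helper name to_tpu_dtypes and the divide-by-255 pixel scaling are illustrative assumptions, not part of any TensorFlow API.

```python
import tensorflow as tf


def to_tpu_dtypes(image, label):
    """Hypothetical helper: cast tensors to TPU-supported dtypes."""
    # uint8 pixels in [0, 255] -> float32 in [0, 1]
    image = tf.cast(image, tf.float32) / 255.0
    # int64 labels -> int32
    label = tf.cast(label, tf.int32)
    return image, label


raw_image = tf.constant([[[0, 128, 255]]], dtype=tf.uint8)  # shape (1, 1, 3)
raw_label = tf.constant(3, dtype=tf.int64)
image, label = to_tpu_dtypes(raw_image, raw_label)
print(image.dtype, label.dtype)

# bfloat16 is the other supported float type; it uses 16 bits per value
# instead of float32's 32 bits.
image_bf16 = tf.cast(image, tf.bfloat16)
```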

Data Loading

This is by far the most important step in TPU training; a correct implementation can give you a huge performance boost.

TensorFlow and TPUs only work with data hosted on GCS (Google Cloud Storage), at least for now.

If you are using an existing dataset from TensorFlow Datasets, use the try_gcs=True flag to load the data from GCS. Otherwise, my recommendation is to host your data in TFRecord format in GCS buckets. I personally use Kaggle to host datasets and get the GCS path from a Kaggle kernel, but that's just me; you can always use your own GCP account.

Below is an example of how I loaded the RANZCR competition dataset from GCS in TFRecord format.

Overview of the code structure

1. Select the TFRecord files to be used for data loading (in my case, I had 15 TFRecord files).
2. Load them using the tf.data.TFRecordDataset API.
3. Map functions over the dataset to decode the labels and images, and convert them to a supported data type.
4. Finally, before training, configure the performance optimizations and other necessary steps, for example batch, prefetch, shuffle, and repeat.
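The steps above can be sketched as a small tf.data pipeline. This is an illustrative sketch under assumptions, not the exact code used for the RANZCR data: the feature names 'image' and 'target', the image size, and the batch size are hypothetical, and write_demo_shards is a hypothetical helper that writes tiny local TFRecord files as a stand-in for the GCS shards so the sketch runs anywhere.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

AUTO = tf.data.experimental.AUTOTUNE
IMAGE_SIZE = [32, 32]  # hypothetical; real competition images are much larger
BATCH_SIZE = 4         # hypothetical; on a TPU, scale it by the replica count

# Hypothetical feature schema: a JPEG-encoded image and an integer label.
FEATURES = {
    'image': tf.io.FixedLenFeature([], tf.string),
    'target': tf.io.FixedLenFeature([], tf.int64),
}


def write_demo_shards(directory, num_files=2, per_file=5):
    """Write tiny TFRecord shards locally as a stand-in for the GCS files."""
    for i in range(num_files):
        path = os.path.join(directory, f'train-{i:02d}.tfrec')
        with tf.io.TFRecordWriter(path) as writer:
            for j in range(per_file):
                pixels = np.random.randint(0, 256, IMAGE_SIZE + [3], dtype=np.uint8)
                jpeg = tf.io.encode_jpeg(pixels).numpy()
                example = tf.train.Example(features=tf.train.Features(feature={
                    'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[jpeg])),
                    'target': tf.train.Feature(int64_list=tf.train.Int64List(value=[j % 2])),
                }))
                writer.write(example.SerializeToString())


def decode_example(serialized):
    """Parse one record, decode the image and cast to TPU-supported dtypes."""
    parsed = tf.io.parse_single_example(serialized, FEATURES)
    image = tf.io.decode_jpeg(parsed['image'], channels=3)  # uint8, dynamic shape
    image = tf.reshape(image, IMAGE_SIZE + [3])              # fix a static shape
    image = tf.cast(image, tf.float32) / 255.0               # uint8 -> float32
    label = tf.cast(parsed['target'], tf.int32)              # int64 -> int32
    return image, label


def get_dataset(filenames, training=True):
    # Step 2: read several shards in parallel.
    ds = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    # Step 3: decode and cast every record.
    ds = ds.map(decode_example, num_parallel_calls=AUTO)
    # Step 4: shuffle/repeat for training, fixed-size batches, prefetch.
    if training:
        ds = ds.shuffle(1024).repeat()
    # drop_remainder=True keeps every batch the same shape, which TPUs prefer.
    ds = ds.batch(BATCH_SIZE, drop_remainder=True)
    return ds.prefetch(AUTO)


# Step 1: select the TFRecord files. With real data this would be a pattern
# such as GCS_PATH + '/train*.tfrec' (GCS_PATH is a placeholder).
data_dir = tempfile.mkdtemp()
write_demo_shards(data_dir)
files = tf.io.gfile.glob(os.path.join(data_dir, 'train*.tfrec'))

images, labels = next(iter(get_dataset(files)))
print(images.shape, images.dtype, labels.dtype)
```

tf.io.gfile.glob accepts gs:// paths as well as local ones, so the same selection step works once the files live in a GCS bucket.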

Model Loading and Compilation

This is fairly easy stuff: all you have to do is make sure that whatever model or layers you are loading are hosted in GCS, and that model compilation happens within the scope of the distribution strategy.

Example code structure:

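def get_model():
    # Load all your model layers here (placeholder architecture)
    model = tf.keras.Sequential([tf.keras.Input(shape=(10,)), tf.keras.layers.Dense(1)])
    # Compile it here
    model.compile(optimizer='adam', loss='mse')
    return model

with tpu_strategy.scope():  # build and compile inside the strategy scope
    model = get_model()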
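Once the model is created under the strategy scope, training is a regular Keras model.fit call. Because a training dataset that uses repeat() never ends, fit needs steps_per_epoch to know where an epoch stops. The sketch below assumes a hypothetical toy model and random in-memory data, and uses tf.distribute.get_strategy() as a stand-in for tpu_strategy so it runs without a TPU.

```python
import numpy as np
import tensorflow as tf

# Stand-in for tpu_strategy so the sketch runs without a TPU (assumption).
strategy = tf.distribute.get_strategy()


def get_model():
    # Hypothetical toy architecture; replace with your own layers.
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(8,)),
        tf.keras.layers.Dense(16, activation='relu'),
        tf.keras.layers.Dense(2, activation='softmax'),
    ])
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model


# The model's variables and its optimizer are created inside the scope,
# so the strategy controls where they live.
with strategy.scope():
    model = get_model()

# Toy data; the dataset repeats forever, so steps_per_epoch is required.
x = np.random.rand(64, 8).astype('float32')
y = np.random.randint(0, 2, size=(64,)).astype('int32')
ds = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(64).repeat().batch(16, drop_remainder=True)
history = model.fit(ds, epochs=2, steps_per_epoch=64 // 16, verbose=0)
```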

That's it! Now you know how to use TPUs together with the TFRecord format to get faster training.

Originally published at https://beginers.tech on January 27, 2021.


