Which one should I choose as a beginner in deep learning: PyTorch or TensorFlow?

I know, I know. It is really daunting to start learning deep learning today. I remember August-September of 2017, when I was just starting to dip my toes into the waters of deep learning. I still remember reading my first paper and what a formidable experience it was, especially for a self-taught programmer like me. I was as confused about the frameworks then as you are now. Stay with me: I will try to talk about both frameworks in a beginner-friendly way, and then you can choose your weapon.

“The master is a beginner that never gave up.” ― Avina Celeste



Origin —

Let’s glance over the origins of these frameworks.

TensorFlow was developed by Google Brain, and Google actively uses it both for prototyping models (i.e. experimentation) and in production.

PyTorch has its origins in Torch, a Lua-based framework that was developed and used at Facebook. However, it is not a wrapper in the way Keras is; PyTorch was rewritten.

Let’s see the code —

I know how it feels to read the seemingly redundant information I provided above. But believe me, when you become a deep learning pro (and I know you will), you will realize that the origins of a framework matter. You will come across something known as dynamic and static computational graphs, and then you will be glad that you knew the origins of these frameworks.
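
If "dynamic graph" sounds abstract, here is a minimal sketch of what it means in PyTorch. The function name `double_until_large` and the threshold of 10 are made up for illustration. The point is that an ordinary Python `while` loop decides how many operations end up in the graph, and autograd still works. In the TensorFlow 1.x style shown later in this post, the graph is declared up front with placeholders and only run afterwards.

```python
import torch


def double_until_large(x, threshold=10.0):
    """Keep doubling x until its norm reaches the threshold.

    The number of loop iterations depends on the data, so the graph
    that autograd records is different for different inputs.
    """
    y = x
    steps = 0
    while y.norm() < threshold:
        y = y * 2
        steps += 1
    return y.sum(), steps


x = torch.tensor([1.0, 2.0], requires_grad=True)
out, steps = double_until_large(x)
out.backward()  # differentiate through however many doublings happened

grad = x.grad.tolist()
print(steps)  # 3 doublings for this input
print(grad)   # [8.0, 8.0], because out = sum(2**3 * x)
```

Feed in a larger tensor and the loop runs fewer times, so the recorded graph is shorter. Nothing had to be declared in advance.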

Moving on. We are coders, right? We want to see how something is implemented in a framework, and that will help us decide which framework suits us.

So, here is a simple convolutional neural network written in PyTorch:

import torch
import torch.nn as nn
import torch.nn.functional as F
# A ConvNet for MNIST dataset
class Net(nn.Module):
    def __init__(self):
        # nn.Module must be initialised before any layers are assigned
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
    def forward(self, x):
        # x has shape (batch, 1, 28, 28)
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)  # flatten each sample to a vector of 320 features
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

You can see that it is intuitive and really quick to write a neural network in PyTorch. PyTorch is usually great for quick prototyping.
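
One thing that puzzled me as a beginner is where the 320 in `nn.Linear(320, 50)` comes from. Here is a small sketch in plain Python that traces the image size through the network above. The helper names `conv_out` and `pool_out` are my own, and the sketch assumes 28 x 28 MNIST images, no padding and stride 1 for the convolutions (the `nn.Conv2d` defaults).

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Side length of a square feature map after a convolution."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel):
    """Side length after max pooling with stride equal to the kernel size."""
    return size // kernel


side = 28                               # MNIST images are 28 x 28
side = pool_out(conv_out(side, 5), 2)   # conv1 then pool: 28 -> 24 -> 12
side = pool_out(conv_out(side, 5), 2)   # conv2 then pool: 12 -> 8 -> 4

channels = 20                           # output channels of conv2
flat_features = channels * side * side
print(flat_features)  # 320
```

So each image leaves the convolutional part as 20 feature maps of 4 x 4, which flatten to 320 numbers. If you change a kernel size or the input resolution, this is the number you have to recompute.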

Now let us see how we can write a simple neural network in TensorFlow:

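import tensorflow as tf  # this listing uses the TensorFlow 1.x graph API
# A ConvNet for MNIST dataset
num_input = 784   # MNIST images are 28 * 28 pixels, flattened
num_classes = 10  # digits 0-9
# define graph inputs
X = tf.placeholder(tf.float32, [None, num_input])
Y = tf.placeholder(tf.float32, [None, num_classes])
keep_prob = tf.placeholder(tf.float32)  # dropout keep probability
# adding some wrappers
def conv2d(x, W, b, strides=1):
    # Conv2D wrapper, with bias and relu activation
    # (the padding argument was cut off in the original listing; 'SAME' is assumed)
    x = tf.nn.conv2d(x, W, strides=[1, strides, strides, 1], padding='SAME')
    x = tf.nn.bias_add(x, b)
    return tf.nn.relu(x)
def maxpool2d(x, k=2):
    # MaxPool2D wrapper (padding='SAME' is assumed here as well)
    return tf.nn.max_pool(x, ksize=[1, k, k, 1],
                          strides=[1, k, k, 1],
                          padding='SAME')
# Defining model
def Net(x, weights, biases, dropout):
    # weights and biases are dicts of tf.Variable objects supplied by the caller
    x = tf.reshape(x, shape=[-1, 28, 28, 1])
    conv1 = conv2d(x, weights['wc1'], biases['bc1'])
    conv1 = maxpool2d(conv1, k=2)
    conv2 = conv2d(conv1, weights['wc2'], biases['bc2'])
    conv2 = maxpool2d(conv2, k=2)
    # Flatten conv2 to match the input size of the first dense layer
    # (the target shape was cut off in the original listing; this is an assumed reconstruction)
    fc1 = tf.reshape(conv2, [-1, weights['wd1'].get_shape().as_list()[0]])
    fc1 = tf.add(tf.matmul(fc1, weights['wd1']), biases['bd1'])
    fc1 = tf.nn.relu(fc1)
    fc1 = tf.nn.dropout(fc1, dropout)
    out = tf.add(tf.matmul(fc1, weights['out']), biases['out'])
    return out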

Looks complicated, right? But it gives you far more control. This is not an entirely fair comparison, though, because TensorFlow also includes a higher-level API known as tf.estimator. You can check this link out for more details about implementing ConvNets using tf.estimator: https://www.tensorflow.org/guide/estimators

People generally say that PyTorch is more pythonic and that its code looks cleaner. If you want to see really advanced opinions from both sides, then I’d recommend this link: https://www.reddit.com/r/MachineLearning/comments/7ziagd/d_discussion_on_pytorch_vs_tensorflow/

Distributed Model Training —

This is a difficult one to simplify. However, I will try.

Assume that you have multiple GPUs or CPUs to work with (Lucky you! I am yet to buy a GPU 🙁). Then you can distribute the training of your model across them to reduce the training time.

PyTorch uses MPI, Gloo and NCCL as communication backends to communicate between devices (CPUs and GPUs).

TensorFlow uses gRPC as its communication protocol. You can also run TensorFlow on HDFS and S3. However, if you want to use the MPI protocol with TensorFlow, then you can take advantage of Horovod.
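
To show where the PyTorch backend names mentioned above actually appear in code, here is a minimal sketch using `torch.distributed`. It is only a toy: it assumes a CPU-only machine, picks the Gloo backend, and starts a "group" of exactly one process so that it can run without a launcher. The function name `all_reduce_demo` and the file-based rendezvous path are my own choices for illustration. Real multi-GPU training would start one process per device and typically use NCCL.

```python
import os
import tempfile

import torch
import torch.distributed as dist


def all_reduce_demo(backend="gloo"):
    """Join a one-process group, sum a tensor across it, return the result."""
    # A file-based rendezvous avoids having to pick a free TCP port.
    rendezvous = os.path.join(tempfile.mkdtemp(), "rendezvous")
    dist.init_process_group(
        backend=backend,                  # "gloo", "nccl" or "mpi"
        init_method=f"file://{rendezvous}",
        rank=0,                           # this process's id in the group
        world_size=1,                     # total number of processes
    )
    try:
        t = torch.tensor([1.0, 2.0, 3.0])
        # Every process ends up with the element-wise sum over all processes.
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
        return t.tolist()
    finally:
        dist.destroy_process_group()


result = all_reduce_demo()
print(result)  # with a single process the sum is just the tensor itself
```

With more processes, each one would call the same function with its own `rank`, and `all_reduce` is the kind of operation used to combine gradients across workers.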

Visualization —

Both PyTorch and TensorFlow use TensorBoard for visualization. PyTorch doesn’t have its own visualization tool yet.

Community, Documentation and Reaching out —

I believe the TensorFlow and PyTorch communities are both really welcoming and helpful. I found the PyTorch documentation to be much neater than TensorFlow's.

Both Google Brain and the PyTorch developers do live streams and roadshows to reach out to developers.

This is me attending the first TensorFlow RoadShow in Bangalore. You can spot me as the guy to the left of the banner that the crowd is holding. I am facing sideways and wearing a black hoodie.

Tensorflow RoadShow!

I hope I was able to help you out. The journey of deep learning is difficult but it is very rewarding. Feel free to contact me! 🙂

Happy Coding!
