In this blog post, we will write a simple convolutional neural network to classify images from the CIFAR-10 dataset. The code is available here and is well commented; download and run it if you want to see things in action.
We will start by importing the necessary libraries.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
These libraries give us the tools we need to build our neural network. “torch.nn” provides the essential layers, such as “Conv2d” and “Linear”, prebuilt for us.
Note – There is a difference between “nn.Conv2d” and “nn.functional.conv2d”. “nn.Conv2d” is a layer (a module) that creates and stores its own weights and bias, so it can be used directly as a convolutional layer. “nn.functional.conv2d” is a plain function to which you pass the input, weights and bias yourself, which is what you want when writing your own custom convolutional layer logic.
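To make the note above concrete, here is a minimal sketch that runs the same convolution both ways. The input is random data of CIFAR-10 image size, used purely for illustration, and the variable names are made up for this example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# One fake RGB image of size 32x32 (batch, channels, height, width).
x = torch.randn(1, 3, 32, 32)

# Module form: the layer creates and owns its weight and bias.
layer = nn.Conv2d(3, 6, 5)
out_module = layer(x)

# Functional form: we pass the weight and bias in explicitly.
# Here we reuse the layer's parameters so both calls do the same work.
out_functional = F.conv2d(x, layer.weight, layer.bias)

same = torch.allclose(out_module, out_functional)
print(tuple(out_module.shape), same)
```

In a model definition the module form is usually more convenient, because the parameters are registered automatically and show up in “net.parameters()”.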
Now we will use the torchvision library to download the data and apply transformations to it.
    import torchvision as tv
    import torchvision.transforms as transforms

    # our transformation pipeline: convert to a tensor, then normalize each channel
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
    ])

    trainset = tv.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
    dataloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=False, num_workers=4)
As you can see, we have added normalization to our dataset. The numbers are standard; I got the per-channel mean and standard deviation for normalization from here (let me know in the comments what the correct value of the standard deviation should be *wink*).
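If you want to check such numbers yourself, the idea is to take the mean and standard deviation of every pixel in each colour channel. Below is a minimal sketch under stated assumptions: “channel_stats” and “normalize” are hypothetical helper names written for this post, and the input is random data standing in for real images, so the printed statistics are not the CIFAR-10 values.

```python
import torch

def channel_stats(images):
    """Per-channel mean and std of a float tensor shaped (N, C, H, W)."""
    mean = images.mean(dim=(0, 2, 3))
    std = images.std(dim=(0, 2, 3))
    return mean, std

def normalize(images, mean, std):
    """Subtract the channel mean and divide by the channel std."""
    return (images - mean.view(1, -1, 1, 1)) / std.view(1, -1, 1, 1)

torch.manual_seed(0)

# Stand-in for a batch of images already scaled to [0, 1].
fake_images = torch.rand(8, 3, 32, 32)

mean, std = channel_stats(fake_images)
normalized = normalize(fake_images, mean, std)

# After normalization each channel has mean ~0 and std ~1.
new_mean, new_std = channel_stats(normalized)
print(new_mean, new_std)
```

To apply this to the real training set you would first need to gather the images into one (N, 3, 32, 32) float tensor scaled to [0, 1]; that step is left out here.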
It is time to define a convolutional neural network model now –
    class OurModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)  # one output per CIFAR-10 class

        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = x.view(-1, 16 * 5 * 5)  # flatten to (batch, 400)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)  # raw scores (logits), no softmax here
            return x
Every model in PyTorch is a subclass of “nn.Module”. We have defined a couple of convolutional layers here with the following settings –
- conv1 – input channels = 3 (because the images are RGB), output channels = 6 (because it has 6 filters/kernels) and kernel size = 5.
- conv2 – input channels = 6 (because conv1 outputs 6 channels), output channels = 16 (the number of filters/kernels) and kernel size = 5.
Then we have three fully connected layers. In the forward function we use “view” to flatten the 16×5×5 output of the conv layers into a vector of 400 values per image, so that it can be passed to the fully connected layers.
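Where does the 16*5*5 in “fc1” come from? It follows from how each layer shrinks the 32×32 CIFAR-10 image. The sketch below traces the sizes in plain Python; “conv_out” is a hypothetical helper written for this post, and it assumes no padding and a stride of 1 unless stated, which matches the layers defined above.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output width/height of a conv or pooling layer on a square input."""
    return (size + 2 * padding - kernel) // stride + 1

size = 32                                   # CIFAR-10 images are 32x32
size = conv_out(size, kernel=5)             # conv1: 32 -> 28
size = conv_out(size, kernel=2, stride=2)   # pool:  28 -> 14
size = conv_out(size, kernel=5)             # conv2: 14 -> 10
size = conv_out(size, kernel=2, stride=2)   # pool:  10 -> 5

channels = 16                               # output channels of conv2
flat = channels * size * size
print(size, flat)  # prints: 5 400
```

If you change a kernel size or add a layer, rerun this kind of calculation and update the input size of “fc1” to match.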
Good, now let’s quickly define the loss function and the optimizer.
    net = OurModel()
    loss_func = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, weight_decay=1e-6, momentum=0.9, nesterov=True)
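One detail worth spelling out: “nn.CrossEntropyLoss” expects raw, unnormalized scores (logits), which is why “fc3” in the model is not followed by a softmax. Here is a minimal sketch, using made-up logits and labels rather than real model outputs, showing that it matches a log-softmax followed by the negative log-likelihood loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Made-up scores for a batch of 4 images over 10 classes.
logits = torch.randn(4, 10)
labels = torch.tensor([3, 8, 0, 6])

# Built-in loss, applied directly to the logits.
loss_builtin = nn.CrossEntropyLoss()(logits, labels)

# The same thing in two explicit steps.
log_probs = F.log_softmax(logits, dim=1)
loss_manual = F.nll_loss(log_probs, labels)

print(loss_builtin.item(), loss_manual.item())
```

Both values should agree up to floating point error, so adding your own softmax before this loss would apply it twice.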
Let’s train our model.
    for epoch in range(2):
        running_loss = 0.0
        for i, data in enumerate(dataloader, 0):
            inputs, labels = data
            optimizer.zero_grad()

            # forward prop
            outputs = net(inputs)
            loss = loss_func(outputs, labels)

            # backprop
            loss.backward()   # compute gradients
            optimizer.step()  # update parameters

            # print statistics
            running_loss += loss.item()
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print('[epoch: %d, minibatch: %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0

    print("Training finished!")
Good work! Our model has been trained. Now let’s test the model’s predictions.
    # torchvision was imported as "tv" above, so we use that name here
    testset = tv.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    # take the first batch of 4 test images
    dataiter = iter(testloader)
    images, labels = next(dataiter)

    # the predicted class is the index of the highest score
    outputs = net(images)
    _, predicted = torch.max(outputs, 1)

    print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
    print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))
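The code above looks at a single batch of four images. To get one number for the whole test set, you can count correct predictions over every batch. Below is a minimal sketch: “evaluate_accuracy” is a hypothetical helper written for this post, and the demo uses a tiny stand-in model and hand-made data so that it runs without downloading anything.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate_accuracy(model, loader):
    """Fraction of samples in `loader` that `model` classifies correctly."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():  # no gradients needed for evaluation
        for inputs, labels in loader:
            outputs = model(inputs)
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return correct / total

# Stand-in data: each row is already a score vector, so an identity
# "model" predicts the position of the 1 in that row.
demo_inputs = torch.eye(4)
demo_labels = torch.tensor([0, 1, 2, 0])  # the last label is deliberately wrong
demo_loader = DataLoader(TensorDataset(demo_inputs, demo_labels), batch_size=2)

demo_accuracy = evaluate_accuracy(nn.Identity(), demo_loader)
print(demo_accuracy)  # prints: 0.75
```

With the objects defined in this post, the call would be “evaluate_accuracy(net, testloader)”.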
Try the above code out and post the result in the comments. I hope the model performed well.
If you like the blog, please follow and share it.
Happy Coding! 🙂