In [1]:
import torch
import torchvision
from torchvision import transforms, datasets
from tqdm import tqdm
train = datasets.MNIST('', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor()
]))
test = datasets.MNIST('', train=False, download=True,
transform=transforms.Compose([
transforms.ToTensor()
]))
/Users/ksv/miniforge3/envs/pytorch/lib/python3.8/site-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:180.) return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
In [2]:
BATCH_SIZE = 10
trainset = torch.utils.data.DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
testset = torch.utils.data.DataLoader(test, batch_size=BATCH_SIZE, shuffle=False)
In [3]:
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(28*28, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, 128)
self.fc4 = nn.Linear(128, 64)
self.fc5 = nn.Linear(64, 32)
self.fc6 = nn.Linear(32, 32)
self.fc7 = nn.Linear(32, 32)
self.fc8 = nn.Linear(32, 10)
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.relu(self.fc3(x))
x = F.relu(self.fc4(x))
x = F.relu(self.fc5(x))
x = F.relu(self.fc6(x))
x = F.relu(self.fc7(x))
x = self.fc8(x)
return x
net = Net()
In [4]:
import torch.optim as optim
loss_fn = nn.CrossEntropyLoss()
optimizer= optim.Adam(net.parameters(), lr=0.001)
In [5]:
# Training
for epoch in range(3):
for data in tqdm(trainset):
X, y = data
y_hat = net(X.view(-1, 28*28)) # predict
loss = loss_fn(y_hat, y) # how wrong we are
optimizer.zero_grad()
loss.backward() # compute gradients
optimizer.step() # adjust weights
100%|██████████| 6000/6000 [00:16<00:00, 370.45it/s] 100%|██████████| 6000/6000 [00:15<00:00, 378.49it/s] 100%|██████████| 6000/6000 [00:17<00:00, 338.50it/s]
In [6]:
correct = 0
total = 0
with torch.no_grad():
for data in tqdm(testset):
X, y = data
y_hat = net(X.view(-1, 28*28))
for idx, i in enumerate(y_hat):
if torch.argmax(i) == y[idx]:
correct += 1
total += 1
print(f'Accuracy: {correct/total * 100}%')
100%|██████████| 1000/1000 [00:00<00:00, 1808.32it/s]
Accuracy: 96.7%
In [7]:
class CNN(nn.Module):
def __init__(self):
super().__init__()
self.conv_sequential = nn.Sequential(
nn.Conv2d(1, 32, 5),
nn.ReLU(),
nn.MaxPool2d(2, 2),
nn.Conv2d(32, 64, 5),
nn.ReLU(),
nn.MaxPool2d(2, 2)
)
self.fc_sequential = nn.Sequential(
nn.Linear(64*4*4, 1024),
nn.ReLU(),
nn.Linear(1024, 512),
nn.ReLU(),
nn.Linear(512, 10)
)
def forward(self, x):
x = self.conv_sequential(x)
x = x.view(-1, 64*4*4)
x = self.fc_sequential(x)
return F.log_softmax(x, dim=1)
In [8]:
# train the CNN
cnn = CNN()
loss_fn_cnn = nn.CrossEntropyLoss()
optimizer_cnn = optim.Adam(cnn.parameters(), lr=0.001)
for epoch in range(3):
for data in tqdm(trainset):
X, y = data
cnn.zero_grad()
y_hat = cnn(X)
loss = loss_fn_cnn(y_hat, y)
loss.backward()
optimizer_cnn.step()
0%| | 0/6000 [00:00<?, ?it/s]/Users/ksv/miniforge3/envs/pytorch/lib/python3.8/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at ../c10/core/TensorImpl.h:1156.) return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode) 100%|██████████| 6000/6000 [00:55<00:00, 109.09it/s] 100%|██████████| 6000/6000 [00:51<00:00, 116.81it/s] 100%|██████████| 6000/6000 [00:50<00:00, 119.51it/s]
In [9]:
# test the CNN
correct = 0
total = 0
with torch.no_grad():
for data in tqdm(testset):
X, y = data
y_hat = cnn(X)
for idx, i in enumerate(y_hat):
if torch.argmax(i) == y[idx]:
correct += 1
total += 1
print(f'Accuracy: {correct/total * 100}%')
100%|██████████| 1000/1000 [00:02<00:00, 439.42it/s]
Accuracy: 99.00999999999999%
In [ ]:
In [ ]: