python | Implementing STN on MNIST

Implementing STN on MNIST. Personal blog: https://cxy-sky.github.io/
Code reference: "PyTorch Framework in Practice (3): Spatial Transformer Network (STN)", Daniel Yuz's blog, CSDN
Theory: "Affine transforms in PyTorch (affine_grid)", liangbaqiang's blog, CSDN
Background: "A Detailed Explanation of Spatial Transformer Networks (STN): One Article to Fully Understand STN", 黄小猿's blog, CSDN

Images are displayed with matplotlib; I didn't install OpenCV.
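For reference, this is roughly how a normalized batch gets turned into something matplotlib can display in the scripts below (a minimal standalone sketch; the random tensor just stands in for a real MNIST batch):

import torch
import torchvision
import matplotlib.pyplot as plt

batch = torch.rand(16, 1, 28, 28) * 2 - 1           # stand-in for a batch normalized to [-1, 1]
grid = torchvision.utils.make_grid(batch, nrow=8)    # tile the batch into one (3, H, W) image
grid = grid.numpy().transpose(1, 2, 0) * 0.5 + 0.5   # CHW -> HWC, undo Normalize((0.5,), (0.5,))
plt.imshow(grid)
plt.show()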
CNN
import torch
from torch import nn, optim


class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3),
        )
        self.linear = nn.Sequential(
            nn.Dropout2d(0.5),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.cnn(x)
        x = x.view(x.size()[0], -1)
        # print(x.size())
        x = self.linear(x)
        return x


if __name__ == '__main__':
    model = CNN()
    x = torch.rand(1, 1, 28, 28)
    print(model)
    y = model(x)
    print(y)
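The in_features=512 of the final Linear layer comes from the pooled feature map: 28x28 -> MaxPool2d(4) -> 7x7 -> MaxPool2d(3) -> 2x2, with 128 channels, so 128 * 2 * 2 = 512. A quick shape check (a sketch assuming exactly the layers above):

import torch
from torch import nn

# the 3x3 convolutions with padding=1 keep the spatial size; only the pooling shrinks it
feat = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=4),
    nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=3),
)
out = feat(torch.rand(1, 1, 28, 28))
print(out.shape)             # torch.Size([1, 128, 2, 2])
print(out.flatten(1).shape)  # torch.Size([1, 512])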

STN
import torch
from torch import nn


class STN(nn.Module):
    def __init__(self):
        super(STN, self).__init__()
        # localization network: convolutional part
        self.location_cov = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
        )
        # localization network: regresses the 2x3 affine parameters theta
        self.localization_linear = nn.Sequential(
            nn.Linear(in_features=10 * 3 * 3, out_features=32),
            nn.ReLU(),
            nn.Linear(in_features=32, out_features=2 * 3)
        )
        # initialize the last layer so theta starts as the identity transform
        self.localization_linear[2].weight.data.zero_()
        self.localization_linear[2].bias.data.copy_(
            torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

        # classification network, identical to the plain CNN
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3),
        )
        self.linear = nn.Sequential(
            nn.Dropout2d(0.5),
            nn.Linear(512, 10)
        )

    def stn(self, x):
        # predict theta from the input, then resample the input with it
        x2 = self.location_cov(x)
        x2 = x2.view(x2.size()[0], -1)
        x2 = self.localization_linear(x2)
        theta = x2.view(x2.size()[0], 2, 3)
        grid = nn.functional.affine_grid(theta, x.size(), align_corners=True)
        x = nn.functional.grid_sample(x, grid, align_corners=True)
        return x

    def forward(self, x):
        x = self.stn(x)
        x = self.cnn(x)
        x = x.view(x.size()[0], -1)
        x = self.linear(x)
        return x


if __name__ == '__main__':
    x = torch.rand(1, 1, 28, 28)
    model = STN()
    print(model)
    print(model(x))
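The bias of the last localization layer is copied from [1, 0, 0, 0, 1, 0], i.e. the identity affine matrix, so before any training the STN leaves the image untouched and the classifier still sees a sensible input. A small sketch of what affine_grid and grid_sample do with that theta (illustrative values, not part of the original code):

import torch
from torch import nn

x = torch.rand(1, 1, 28, 28)
theta = torch.tensor([[1., 0., 0., 0., 1., 0.]]).view(1, 2, 3)  # identity transform
grid = nn.functional.affine_grid(theta, x.size(), align_corners=True)
warped = nn.functional.grid_sample(x, grid, align_corners=True)
print(torch.allclose(warped, x, atol=1e-5))  # True: the identity theta reproduces the input

# any other theta warps the sampling grid instead, e.g.
# [[0.7, 0, 0], [0, 0.7, 0]] samples the central 70% of the image (a zoom-in)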

train
import numpy as np
import torch
from torchvision import transforms
import torch.utils.data
import matplotlib.pyplot as plt
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import ImageFolder
from PIL import Image
from torch import nn, optim

from stn.CNN import CNN
from stn.STN import STN

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# data preprocessing: the random rotation gives the STN misaligned digits to correct
transform = transforms.Compose([
    transforms.RandomRotation(45),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_data = torchvision.datasets.MNIST('../data/mnist', download=True, train=True, transform=transform)
test_data = torchvision.datasets.MNIST('../data/mnist', download=True, train=False, transform=transform)

train_loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=True)

# preview one batch of (rotated, normalized) training images
data_iter = iter(train_loader)
imgs = torchvision.utils.make_grid(next(data_iter)[0], 8)
imgs = imgs.numpy().transpose(1, 2, 0)
imgs = imgs * 0.5 + 0.5
plt.imshow(imgs)
plt.show()

# model = CNN()
model = STN()
model = model.to(device)
loss_fun = nn.CrossEntropyLoss().to(device)
opt_fun = optim.Adam(params=model.parameters(), lr=0.001)

loss = 0
train_acc_count = []
test_acc_count = []
train_loss = []
test_loss = []


def train(epoch):
    for i in range(epoch):
        for index, data in enumerate(train_loader):
            imgs = data[0].to(device)
            labels = data[1].to(device)
            outputs = model(imgs)
            loss = loss_fun(outputs, labels)
            loss.backward()
            opt_fun.step()
            opt_fun.zero_grad()
            if index % 100 == 0:
                print("epoch {}, step {}, loss: {}".format(i + 1, index, loss.item()))
                train_loss.append(loss.item())


def test():
    # accumulate correct predictions and sample counts over the whole test set
    correct_count = 0
    test_count = 0
    for imgs, labels in test_loader:
        with torch.no_grad():
            outputs = model(imgs.to(device))
            correct_count += (torch.max(outputs, dim=1)[1] == labels.to(device)).sum().item()
            test_count += labels.size()[0]
    print("test accuracy: {}".format(correct_count / test_count))


if __name__ == '__main__':
    # fix random seeds so results are reproducible
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True
    train(10)
    test()
    save_path = '../model/mnistStn.pth'
    torch.save(model.state_dict(), save_path)
    plt.plot(train_loss)
    plt.show()

showImage
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision
import torch
import matplotlib.pyplot as plt

from stn.STN import STN

transform = transforms.Compose([
    transforms.RandomRotation(45),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_data = torchvision.datasets.MNIST('../data/mnist', download=True, train=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)

# a batch of rotated inputs, before the spatial transform
data_iter = iter(train_loader)
imgs, labels = next(data_iter)
pre = torchvision.utils.make_grid(imgs, 8)
pre = pre.numpy().transpose(1, 2, 0)
pre = pre * 0.5 + 0.5
plt.subplot(2, 1, 1)
plt.imshow(pre)
plt.title('pre')

# the same batch after the trained localization network has warped it
model = STN()
model.load_state_dict(torch.load('../model/mnistStn.pth'))
now = model.stn(imgs).detach()
now = torchvision.utils.make_grid(now, 8)
now = now.numpy().transpose(1, 2, 0)
now = now * 0.5 + 0.5
plt.subplot(2, 1, 2)
plt.imshow(now)
plt.title('now')

plt.show()

train, epoch = 10
[image: training output]

The images after the transform (shown below) still feel quite magical.
[image: the batch before ('pre') and after ('now') the spatial transform]
