AI|PyTorch实现LeNet-5

【AI|PyTorch实现LeNet-5】
目录
建立模型 lenet5
训练模型
测试模型
下面文章写得很好,代码也写得很清晰。
PyTorch实现经典网络之LeNet5 - 简书 (jianshu.com)
数据集:MNIST in CSV | Kaggle
可能是环境不一样,修改了两处代码:
1. train 41行:

loss_list.append(loss.item())

2. test 32行:
plt.plot(testList, accuracy_list, "r-", label="Test")

源代码:
建立模型 lenet5
import torch.nn as nnclass LeNet5(nn.Module): def __init__(self): super(LeNet5, self).__init__() # 包含一个卷积层和池化层,分别对应LeNet5中的C1和S2, # 卷积层的输入通道为1,输出通道为6,设置卷积核大小5x5,步长为1 # 池化层的kernel大小为2x2 self._conv1 = nn.Sequential( nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1), nn.MaxPool2d(kernel_size=2) ) # 包含一个卷积层和池化层,分别对应LeNet5中的C3和S4, # 卷积层的输入通道为6,输出通道为16,设置卷积核大小5x5,步长为1 # 池化层的kernel大小为2x2 self._conv2 = nn.Sequential( nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1), nn.MaxPool2d(kernel_size=2) ) # 对应LeNet5中C5卷积层,由于它跟全连接层类似,所以这里使用了nn.Linear模块 # 卷积层的输入通特征为4x4x16,输出特征为120x1 self._fc1 = nn.Sequential( nn.Linear(in_features=4 * 4 * 16, out_features=120) ) # 对应LeNet5中的F6,输入是120维向量,输出是84维向量 self._fc2 = nn.Sequential( nn.Linear(in_features=120, out_features=84) ) # 对应LeNet5中的输出层,输入是84维向量,输出是10维向量 self._fc3 = nn.Sequential( nn.Linear(in_features=84, out_features=10) )def forward(self, input): # 前向传播 # MNIST DataSet image's format is 28x28x1 # [28,28,1]--->[24,24,6]--->[12,12,6] conv1_output = self._conv1(input) # [12,12,6]--->[8,8,,16]--->[4,4,16] conv2_output = self._conv2(conv1_output) # 将[n,4,4,16]维度转化为[n,4*4*16] conv2_output = conv2_output.view(-1, 4 * 4 * 16) # [n,256]--->[n,120] fc1_output = self._fc1(conv2_output) # [n,120]-->[n,84] fc2_output = self._fc2(fc1_output) # [n,84]-->[n,10] fc3_output = self._fc3(fc2_output) return fc3_output

训练模型
import torch import torch.nn as nn import torch.optim as optimimport pandas as pd import matplotlib.pyplot as plt from AI.CNN.LeNet_5.lenet5 import LeNet5train_data = https://www.it610.com/article/pd.DataFrame(pd.read_csv("Data/mnist_train.csv"))model = LeNet5() print(model)# 定义交叉熵损失函数 loss_fc = nn.CrossEntropyLoss() # 用model的参数初始化一个随机梯度下降优化器 optimizer = optim.SGD(params=model.parameters(), lr=0.001, momentum=0.78) loss_list = [] x = []# 迭代次数1000次 for i in range(1000): # 小批量数据集大小设置为30 batch_data = https://www.it610.com/article/train_data.sample(n=30, replace=False) # 每一条数据的第一个值是标签数据 batch_y = torch.from_numpy(batch_data.iloc[:, 0].values).long() # 图片信息,一条数据784维将其转化为通道数为1,大小28*28的图片。 batch_x = torch.from_numpy(batch_data.iloc[:, 1::].values).float().view(-1, 1, 28, 28)# 前向传播计算输出结果 prediction = model.forward(batch_x) # 计算损失值 loss = loss_fc(prediction, batch_y) # Clears the gradients of all optimized optimizer.zero_grad() # back propagation algorithm loss.backward() # Performs a single optimization step (parameter update). optimizer.step() print("第%d次训练,loss为%.3f" % (i, loss.item())) loss_list.append(loss.item()) x.append(i)# Saves an object to a disk file. torch.save(model.state_dict(), "TrainedModel/LeNet5.pkl") print('Networks''s keys: ', model.state_dict().keys()) print(x) print(loss_list)plt.figure() plt.xlabel("number of epochs") plt.ylabel("loss") plt.plot(x, loss_list, "r-") plt.show()

测试模型
import torch import numpy as np import pandas as pd import matplotlib.pyplot as plt from AI.CNN.LeNet_5.lenet5 import LeNet5model = LeNet5() test_data = https://www.it610.com/article/pd.DataFrame(pd.read_csv("Data/mnist_test.csv")) # Load model parameters model.load_state_dict(torch.load("TrainedModel/LeNet5.pkl"))accuracy_list = [] testList = []with torch.no_grad(): # 进行一百次测试 for i in range(100): # 每次从测试集中随机挑选50个样本 batch_data = https://www.it610.com/article/test_data.sample(n=50, replace=False) batch_x = torch.from_numpy(batch_data.iloc[:, 1::].values).float().view(-1, 1, 28, 28) batch_y = batch_data.iloc[:, 0].values prediction = np.argmax(model(batch_x).numpy(), axis=1) acccurcy = np.mean(prediction == batch_y) print("第%d组测试集,准确率为%.3f" % (i, acccurcy)) accuracy_list.append(acccurcy) testList.append(i)plt.figure() plt.xlabel("number of tests") plt.ylabel("accuracy rate") plt.ylim(0, 1) plt.plot(testList, accuracy_list, "r-", label="Test") plt.legend() plt.show()


    推荐阅读