Machine Learning | Several Ways to View a Network Architecture in PyTorch

1. print(model)

import torch
from torch import nn
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
print(net)

Output:
Sequential(
  (0): Linear(in_features=4, out_features=8, bias=True)
  (1): ReLU()
  (2): Linear(in_features=8, out_features=1, bias=True)
)
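print(net) shows the module hierarchy but not the shape or size of each weight. A minimal sketch using the standard nn.Module methods named_parameters() and parameters() lists every weight and bias tensor and totals the element counts. A Linear(in, out) layer holds in*out weights plus out biases, so this network should come to 4*8+8 = 40 and 8*1+1 = 9, i.e. 49 parameters, the same total torchsummary reports.

```python
import torch
from torch import nn

# Same toy network as above.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

# named_parameters() yields (name, tensor) pairs such as ("0.weight", ...).
# ReLU (index 1) has no parameters, so it never appears here.
for name, param in net.named_parameters():
    print(f"{name:10s} shape={tuple(param.shape)} numel={param.numel()}")

total = sum(p.numel() for p in net.parameters())
trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
print("Total params:", total)
print("Trainable params:", trainable)
```

Because ReLU owns no tensors, the names jump from "0.*" to "2.*"; the numeric prefixes are the module indices inside the Sequential container.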
2. torchsummary
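import torch
from torch import nn
from torchsummary import summary  # third-party package: pip install torchsummary
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
# summary() prints the table itself and returns None, so it is not wrapped in print().
# The model lives on the CPU, while summary() defaults to device="cuda".
summary(net, input_size=(2, 4), device="cpu")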

Output:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                 [-1, 2, 8]              40
              ReLU-2                 [-1, 2, 8]               0
            Linear-3                 [-1, 2, 1]               9
================================================================
Total params: 49
Trainable params: 49
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
----------------------------------------------------------------
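torchsummary builds this table by pushing a dummy input through the network and recording each layer's output with forward hooks. The leading -1 in each shape stands for an arbitrary batch size, which is why input_size=(2, 4) excludes the batch dimension. A minimal sketch of the same idea using only PyTorch's built-in register_forward_hook is shown below; layer_summary is a hypothetical helper (not part of torch or torchsummary), and it reproduces only the Layer, Output Shape and Param # columns, not the memory estimates.

```python
import torch
from torch import nn

def layer_summary(model: nn.Module, input_size: tuple) -> list:
    """Hypothetical helper: return (label, output shape, param count) per leaf layer.

    Runs one forward pass on a zero-filled batch of size 1 and records each
    leaf module's output via forward hooks. The batch dimension is reported
    as -1, mirroring torchsummary's output.
    """
    rows = []
    hooks = []

    def hook(module, inputs, output):
        # Number layers in execution order, e.g. "Linear-1", "ReLU-2".
        label = f"{type(module).__name__}-{len(rows) + 1}"
        n_params = sum(p.numel() for p in module.parameters(recurse=False))
        rows.append((label, (-1,) + tuple(output.shape[1:]), n_params))

    # Hook only leaf modules (no children), e.g. Linear and ReLU,
    # so the Sequential container itself is not listed.
    for module in model.modules():
        if not list(module.children()):
            hooks.append(module.register_forward_hook(hook))

    try:
        with torch.no_grad():
            model(torch.zeros(1, *input_size))
    finally:
        # Always detach the hooks so later forward passes are unaffected.
        for h in hooks:
            h.remove()
    return rows

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
rows = layer_summary(net, input_size=(2, 4))
for label, shape, n in rows:
    print(f"{label:>20}  {str(list(shape)):>25} {n:>15}")
print("Total params:", sum(n for _, _, n in rows))
```

Hooking only leaf modules keeps the listing flat like torchsummary's; for nested models you may prefer to label rows with the names from named_modules() instead.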
