PyTorch实例:如何打印神经网络结构

分类:知识百科 日期: 点击:0

使用PyTorch可以很容易地打印神经网络的结构。具体的实现方法如下:

1. 建立神经网络模型

# 建立一个简单的网络
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 30)
        self.fc3 = nn.Linear(30, 10)

net = Net()

2. 使用print_summary()函数打印网络结构

# 导入打印神经网络结构模块
from torchsummary import summary

# 使用summary函数打印神经网络结构
summary(net, (10,))

3. 输出结果

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                  [-1, 20]              220
            Linear-2                  [-1, 30]              630
            Linear-3                  [-1, 10]              310
================================================================
Total params: 1,160
Trainable params: 1,160
Non-trainable params: 0
----------------------------------------------------------------

从输出结果可以看出,网络结构由三层线性层组成,每层的输入输出形状,以及参数数量都能够被打印出来。

标签:

版权声明

1. 本站所有素材,仅限学习交流,仅展示部分内容,如需查看完整内容,请下载原文件。
2. 会员在本站下载的所有素材,只拥有使用权,著作权归原作者所有。
3. 所有素材,未经合法授权,请勿用于商业用途,会员不得以任何形式发布、传播、复制、转售该素材,否则一律封号处理。
4. 如果素材损害你的权益请联系客服QQ:77594475 处理。