04_pytorch_build_the_neural_network
BUILD THE NEURAL NETWORK
神经网络的构成是什么?
神经网络由对数据执行操作的层/模块组成。
层和模块在哪里?
torch.nn 命名空间提供了构建您自己的神经网络所需的所有构建块。
PyTorch 中的每个模块都是 nn.Module 的子类。
神经网络的的嵌套构成
神经网络本身就是一个模块,由其他模块(层)组成。
嵌套结构的好处
这种嵌套结构允许轻松构建和管理复杂的架构。
在以下部分中,我们将构建一个神经网络来对 FashionMNIST 数据集中的图像进行分类。
1 | import os |
Get Device for Training
如果我们在具有GPU的设备上,可以使用GPU加速。
如何使用GPU?
我们希望能够在 GPU 或 MPS 等硬件加速器(如果可用)上训练我们的模型。让我们检查一下 torch.cuda 或者 torch.backends.mps是否可用,否则我们使用 CPU。
1 | device = ( |
Define the Class
依照我们之前的说法,神经网络本身就是一个module,所以我们需要继承nn.Module。
我们通过子类化 nn.Module
来定义神经网络,并初始化 __init__
中的神经网络层。每个 nn.Module
子类都实现 forward
方法中对输入数据的操作。
1 | class NeuralNetwork(nn.Module): |
我们创建 NeuralNetwork
的实例,并将其移动到 device
,并打印其结构。
1 | model = NeuralNetwork().to(device) |
为了使用该模型,我们将输入数据传递给它。这将执行模型的 forward
以及一些background operations。不要直接调用 model.forward()
!
在输入上调用模型会返回一个二维张量,其中 dim=0 对应于每个类的 10 个原始预测值的每个输出,dim=1 对应于每个输出的各个值。我们通过将预测概率传递给 nn.Softmax
模块的实例来获取预测概率。
1 | X = torch.rand(1, 28, 28, device=device) |
Model Layers
让我们分解 FashionMNIST 模型中的各个层。为了说明这一点,我们将采用 3 张大小为 28x28 的图像的小批量样本,看看当我们将其传递到网络时会发生什么。
1 | input_image = torch.rand(3,28,28) |
nn.Flatten
我们初始化 nn.Flatten 层,将每个 2D 28x28 图像转换为 784 个像素值的连续数组(维持小批量维度(在 dim=0 时))。
1 | flatten = nn.Flatten() |
1 | torch.Size([3, 784]) |
nn.Linear
linear layer是一个使用其存储的权重和偏差对输入变量线性变换的模块。
1 | layer1 = nn.Linear(in_features=28*28, out_features=20) |
nn.ReLU
非线性激活在模型的输入和输出之间创建复杂的映射。它们在线性变换后应用以引入非线性,帮助神经网络学习各种现象。
在此模型中,我们在线性层之间使用 nn.ReLU,但还有其他激活可以在模型中引入非线性。
1 | print(f"Before ReLU: {hidden1}\n\n") |
1 | Before ReLU: tensor([[ 0.4158, -0.0130, -0.1144, 0.3960, 0.1476, -0.0690, -0.0269, 0.2690, |
nn.Sequential
nn.Sequential 是模块的有序容器。数据按照定义的相同顺序传递通过所有模块。您可以使用顺序容器来组合一个快速网络,例如 seq_modules
.
1 | seq_modules = nn.Sequential( |
nn.Softmax
神经网络的最后一个线性层返回 logits - [-infty, infty] 中的原始值 - 被传递到 nn.Softmax 模块。 Logits 缩放为值 [0, 1],表示模型对每个类别的预测概率。 dim
参数指示维度,沿该维度值的总和必须为 1。
1 | softmax = nn.Softmax(dim=1) |
Model Parameters
神经网络内的许多层都是参数化的,即具有在训练期间优化的相关权重和偏差。子类化 nn.Module
自动跟踪模型对象内定义的所有字段,并使所有参数可使用模型的 parameters()
或 named_parameters()
方法访问。
在此示例中,我们迭代每个参数,并打印其大小及其值的预览。
1 | print(f"Model structure: {model}\n\n") |