pytorch notes
z

数据集加载流程:

  1. torchvision.datasets中已经包含众多数据集,先建立一个数据集对。

    1
    mnist_train = torchvision.datasets.MNIST(root="\data", train=True, download=True, transform= transform)

    不过在建立对象的时候,需要指定transform来确定对原始数据集中的图片的转换过程。

    1
    transform = torchvision.transforms.ToTensor() # 定义transform为ToTensor(),将原始PLT数据转换为Tensor

    又例如:

    1
    2
    3
    transform = transforms.Compose(
    [transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # 定义transform为这一系列步骤,先将图片转换为tensor,在将其进行归一化,第一个括号代表将RGB三通道的数据均值都归一化到0.5,第二个括号代表将均方差同样也归一化到0.5
  2. 使用torch.utils.data中的DataLoader来加载之前定义好的数据集对象

    1
    train_loader = DataLoader(mnist_train, batch_size=8, shuffle=True, num_workers=4)
  3. 读取DataLoader中的数据,一般在训练的时候,使用

    1
    for batch, (img, label) in enumerate(train_loader):

    在测试时,使用

    1
    for img, label in test_loader:

图片显示:

 需要导入`matploitlab`来帮助显示图片

直接通过```for img, label in train_loader``` 得到的`img`的维度是四维的 [ $ batch\_size  \times channels \times width \times hight$ ], 例如,从`Mnist`中读取的数据维度为: `torch.Size([8, 1, 28, 28])` 表示这个batch中有八张图片,由于MNIST中的图片为灰度图,因此`channel`为1。

要显示单张图片,
1
2
img = img[0] # 取出第一张图片
plt.imshow(img[0], cmap="gray") # 灰度图像显示时,接受的图片维度为(width,hight),因此这里再次取[0],使(channel, width, hight)降为(width, hight)
要显示单张图片:
1
2
img = torchvision.utils.make_grid(img).numpy()
plt.imshow(np.transpose(img, (1, 2, 0))) #这里的(1, 2, 0) 指的是,将原始维度(channel, width, hight)转换为 (width, hight, channel)