pytorch notes

数据集加载流程:
torchvision.datasets
中已经包含众多数据集,先建立一个数据集对。1
mnist_train = torchvision.datasets.MNIST(root="\data", train=True, download=True, transform= transform)
不过在建立对象的时候,需要指定
transform
来确定对原始数据集中的图片的转换过程。1
transform = torchvision.transforms.ToTensor() # 定义transform为ToTensor(),将原始PLT数据转换为Tensor
又例如:
1
2
3transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # 定义transform为这一系列步骤,先将图片转换为tensor,在将其进行归一化,第一个括号代表将RGB三通道的数据均值都归一化到0.5,第二个括号代表将均方差同样也归一化到0.5使用
torch.utils.data
中的DataLoader
来加载之前定义好的数据集对象1
train_loader = DataLoader(mnist_train, batch_size=8, shuffle=True, num_workers=4)
读取DataLoader中的数据,一般在训练的时候,使用
1
for batch, (img, label) in enumerate(train_loader):
在测试时,使用
1
for img, label in test_loader:
图片显示:
需要导入`matploitlab`来帮助显示图片
直接通过```for img, label in train_loader``` 得到的`img`的维度是四维的 [ $ batch\_size \times channels \times width \times hight$ ], 例如,从`Mnist`中读取的数据维度为: `torch.Size([8, 1, 28, 28])` 表示这个batch中有八张图片,由于MNIST中的图片为灰度图,因此`channel`为1。
要显示单张图片,
1 | img = img[0] # 取出第一张图片 |
要显示单张图片:
1 | img = torchvision.utils.make_grid(img).numpy() |