【背景】
Pytorch为我们准备了许多灵活地数据加载方法,也自带了多种数据集(Cifar, MNIST, FASHIONMNIST等)。但有时我们需要自己定义数据集以适应自己的研究/开发需求。
【官方文档】
在torch中,我们可以很简单地定义Datasets,官方文档中描述如下:

Dataset类是Pytorch中图像数据集中最为重要的一个类,也是Pytorch中所有数据集加载类中应该继承的父类。其中父类中的两个私有成员函数(见下code)必须被重载,否则将会触发错误提示.
class DATASETSNAME(torch.utils.data.Dataset) def __getitem__(self, index): raise NotImplementedError def __len__(self): raise NotImplementedError
其中__len__应返回数据集的大小(有时也叫长度),而__getitem__则应实现能够支持数据集索引的函数。
【getitem Function】
从函数的参数上看,__getitem__(self, index)接收一个index,然后返回图片数据和标签,这里的index通常指的是list的索引值,而list中的元素应该包含了图片数据的路径和标签信息。所以从本质上来看,pytorch将我们的数据集当做list来进行维护,我们只需要输入index即可直接从list中获取我们想要的数据。而对于list的维护也是我们应该重点关注的点。在本文实现例子中,VOC2012数据集将所有的图片名称规整地放置于:/VOCdevkit/VOC2012/Imagesets/分任务子文件夹/xxx.txt中。见下图:

打开其中的txt文件,得到如下内容:

打开数据的其他文件夹我们观察到,这些以行为单位的数字代表着一张图片的ID,然后数据集以不同的文件夹对图片进行了分类。如:

其中:
| 文件夹名 | 描述 |
| Annotations | 存储每幅图在目标检测任务中需要的语义信息xml文档 |
| ImageSets | 存储图像的标签信息(用以维护list) |
| JPEGImages | 存储原始图像 |
| SegmentationClass | 存储语义分割任务需要的图像(G.T) |
| SegmentationObject | 存储目标分割需要的图像(G.T) |
| SegmentationAug |
存储增强版语义分割任务训练图像(G.T) |
我们打开JPEGImages看一下数据集的内容:

那么我们的整体思路就有了:
我们通过读取:ImagetSets文件夹下的txt文件获取不同任务需要的数据集图像ID。(这个ID是制作数据集的作者提前写好的),然后我们将这些ID转换成文件存储路径(训练数据,GroundTruth, 训练验证数据),再将路径信息转换成为list,该list中每一个元素对应一个样本。这样我们就可以通过__getitem__(),函数获取到数据图像和标签。
代码如下:
def _list_files(self):
self.image_dir = os.path.join(self.root, "JPEGImages")
self.label_dir = os.path.join(self.root, "SegmentationClass")
if self.split in ['val', 'train', 'train_val', 'test']:
self.file_list = os.path.join(self.root, "ImageSets\Segmentation", self.split+".txt")
self.file_list = tuple(open(self.file_list, "r"))
image_list = [os.path.join(self.image_dir, id.replace("\n", '') + '.jpg') for id in self.file_list]
label_list = [os.path.join(self.label_dir, id.replace("\n", '') + '.png') for id in self.file_list]
self.images = image_list
self.labels = label_list
else:
raise ValueError("Invalid split name:{}".format(self.split))
读取数据图像代码如下:
def _load_datas(self, index):
image_id = self.file_list[index].replace('\n', '')
image_path = self.images[index]
label_path = self.labels[index]
# important!!! The image_path should not contain the charset with chinese location
image = cv2.imread(image_path, cv2.IMREAD_COLOR).astype(np.float32)
label = np.asarray(Image.open(label_path), dtype=np.int32)
return image_id, image, label
【数据预处理】
我们这里并没有直接在__getitem__函数中构造function,而是将转换list操作单独写入了一个方法中。这是为了方便我们对数据集进行一些预处理操作。(均值化,随机裁剪,填充,随机仿射等)。这是作者在阅读文献时积累的几个代码(基于cv2, numpy, PIL, random),但其实pytorch也为我们提供了庞大的数据预处理的库。详见:torchvision.transforms
随机裁剪代码如下:
def randomCrop(image, label, mean_bgr, crop_size=382):
# Handle some except produres by height or width lower than the size of cropping operation
h, w =label.shape
pad_h = max(crop_size - h, 0)
pad_w = max(crop_size - w, 0)
mean_bgr = np.array(mean_bgr)
pad_kwargs = {
"top": 0,
"bottom": pad_h,
"left": 0,
"right": pad_w,
"borderType": cv2.BORDER_CONSTANT
}
if pad_h > 0 or pad_w > 0:
image = cv2.copyMakeBorder(image, value=mean_bgr, **pad_kwargs)
label = cv2.copyMakeBorder(label, value=0, **pad_kwargs)
# Handle ended
# Random Cropping
h, w = label.shape
start_h = random.randint(0, h - crop_size)
start_w = random.randint(0, w - crop_size)
end_h = start_h + crop_size
end_w = start_w + crop_size
image = image[start_h:end_h, start_w:end_w]
label = label[start_h:end_h, start_w:end_w]
# Cropping ended
return image, label
随机翻转代码如下:
def randomFlip(image, label): if random.random() < 0.5: image = np.fliplr(image).copy() label = np.fliplr(label).copy() return image, label
随机放缩代码如下:
def randomScale(image, label, scales=(0.5, 0.75, 1.0, 1.25, 1.5)): h, w = label.shape scale_factor = random.choice(scales) h = int(h * scale_factor) w = int(w * scale_factor) image = cv2.resize(image, (w,h), interpolation=cv2.INTER_LINEAR) label = Image.fromarray(label).resize((w, h), resample=Image.NEAREST) label = np.asarray(label, dtype=np.int64) return image, label
【完整实现VOC2012】
VOC.py代码如下:
# -*- coding:utf-8 -*-
import torch
from torch.utils import data
from utils.dsetsHandle import randomFlip, randomCrop, randomScale
import cv2
from PIL import Image
import numpy as np
import os
class VOC12Dataset(data.Dataset):
def __init__(self, root_path, split, mean_bgr=(104.008, 116.669, 122.675), augment=0):
self.root = root_path
self.augment = augment
self.split = split
self._list_files()
self.mean_bgr = mean_bgr
self.images
self.labels
self._load_datas(0)
def __len__(self):
return len(self.file_list)
def __getitem__(self, index):
(image_id, image, label) = self._load_datas(index)
#数据预处理,增强操作
if self.augment == 1:
image, label = randomScale(self.images[index], self.labels[index])
if self.augment == 2:
image, label = randomCrop(self.images[index], self.labels[index], self.mean_bgr)
if self.augment == 3:
image, label = randomFlip(self.images[index], self.labels[index])
image -= self.mean_bgr #均值化处理
#image = image.transpose(2, 0, 1)
return image_id, image.astype(np.float32), label.astype(np.int64)
def _list_files(self):
self.image_dir = os.path.join(self.root, "JPEGImages")
self.label_dir = os.path.join(self.root, "SegmentationClass")
if self.split in ['val', 'train', 'train_val', 'test']:
self.file_list = os.path.join(self.root, "ImageSets\Segmentation", self.split+".txt")
self.file_list = tuple(open(self.file_list, "r"))
image_list = [os.path.join(self.image_dir, id.replace("\n", '') + '.jpg') for id in self.file_list]
label_list = [os.path.join(self.label_dir, id.replace("\n", '') + '.png') for id in self.file_list]
self.images = image_list
self.labels = label_list
else:
raise ValueError("Invalid split name:{}".format(self.split))
def _load_datas(self, index):
image_id = self.file_list[index].replace('\n', '')
image_path = self.images[index]
label_path = self.labels[index]
# important!!! The image_path should not contain the charset with chinese location
image = cv2.imread(image_path, cv2.IMREAD_COLOR).astype(np.float32)
label = np.asarray(Image.open(label_path), dtype=np.int32)
return image_id, image, label
Utils.dsetsHandle.py
# -*- coding:utf-8 -*-
import cv2
import numpy as np
from PIL import Image
import random
def randomScale(image, label, scales=(0.5, 0.75, 1.0, 1.25, 1.5)):
h, w = label.shape
scale_factor = random.choice(scales)
h = int(h * scale_factor)
w = int(w * scale_factor)
image = cv2.resize(image, (w,h), interpolation=cv2.INTER_LINEAR)
label = Image.fromarray(label).resize((w, h), resample=Image.NEAREST)
label = np.asarray(label, dtype=np.int64)
return image, label
def randomCrop(image, label, mean_bgr, crop_size=382):
# Handle some except produres by height or width lower than the size of cropping operation
h, w =label.shape
pad_h = max(crop_size - h, 0)
pad_w = max(crop_size - w, 0)
mean_bgr = np.array(mean_bgr)
pad_kwargs = {
"top": 0,
"bottom": pad_h,
"left": 0,
"right": pad_w,
"borderType": cv2.BORDER_CONSTANT
}
if pad_h > 0 or pad_w > 0:
image = cv2.copyMakeBorder(image, value=mean_bgr, **pad_kwargs)
label = cv2.copyMakeBorder(label, value=0, **pad_kwargs)
# Handle ended
# Random Cropping
h, w = label.shape
start_h = random.randint(0, h - crop_size)
start_w = random.randint(0, w - crop_size)
end_h = start_h + crop_size
end_w = start_w + crop_size
image = image[start_h:end_h, start_w:end_w]
label = label[start_h:end_h, start_w:end_w]
# Cropping ended
return image, label
def randomFlip(image, label):
if random.random() < 0.5:
image = np.fliplr(image).copy()
label = np.fliplr(label).copy()
return image, label
def BGR2RGB(image):
image = image.copy()
temp = image[:,:, 0].copy()
image[:, :, 0] = image[:, :, 2].copy()
image[:, :, 2] = temp
return image
【测试与结果】
测试代码如下:
# -*- coding:utf-8 -*-
from datasets.VOC import VOC12Dataset
import matplotlib.pyplot as plt
import cv2
a = VOC12Dataset(root_path='E:\VOCtrainval_11-May-2012\VOCdevkit\VOC2012', split='train')
plt.imshow(a[0][2])
plt.show()
cv2.imshow('a', a[0][1])
cv2.waitKey(0)
cv2.destroyAllWindows()
结果展示见下图:

其中:左图为原始图像,中间图为经过数据集加载后得到的图像,右图为pyplot输出的Ground-truth图像。