Warning: Undefined array key "HTTP_ACCEPT_LANGUAGE" in /www/wwwroot/blog/wp-content/plugins/UEditor-KityFormula-for-wordpress/main.php on line 13
【Pytorch基础】自定义数据集 – Machine World

【背景】

    Pytorch为我们准备了许多灵活地数据加载方法,也自带了多种数据集(Cifar, MNIST, FASHIONMNIST等)。但有时我们需要自己定义数据集以适应自己的研究/开发需求。

【官方文档】

在torch中,我们可以很简单地定义Datasets,官方文档中描述如下:

image.png

    Dataset类是Pytorch中图像数据集中最为重要的一个类,也是Pytorch中所有数据集加载类中应该继承的父类。其中父类中的两个私有成员函数(见下code)必须被重载,否则将会触发错误提示.

class DATASETSNAME(torch.utils.data.Dataset)
    def __getitem__(self, index):
        raise NotImplementedError
    def __len__(self):
        raise NotImplementedError

其中__len__应返回数据集的大小(有时也叫长度),而__getitem__则应实现能够支持数据集索引的函数。

【getitem Function】

    从函数的参数上看,__getitem__(self, index)接收一个index,然后返回图片数据和标签,这里的index通常指的是list的索引值,而list中的元素应该包含了图片数据的路径和标签信息。所以从本质上来看,pytorch将我们的数据集当做list来进行维护,我们只需要输入index即可直接从list中获取我们想要的数据。而对于list的维护也是我们应该重点关注的点。在本文实现例子中,VOC2012数据集将所有的图片名称规整地放置于:/VOCdevkit/VOC2012/Imagesets/分任务子文件夹/xxx.txt中。见下图:

image.png

打开其中的txt文件,得到如下内容:

image.png

打开数据的其他文件夹我们观察到,这些以行为单位的数字代表着一张图片的ID,然后数据集以不同的文件夹对图片进行了分类。如:

image.png

其中:

文件夹名 描述
Annotations 存储每幅图在目标检测任务中需要的语义信息xml文档
ImageSets 存储图像的标签信息(用以维护list)
JPEGImages 存储原始图像
SegmentationClass 存储语义分割任务需要的图像(G.T)
SegmentationObject 存储目标分割需要的图像(G.T)
SegmentationAug

存储增强版语义分割任务训练图像(G.T)

我们打开JPEGImages看一下数据集的内容:

image.png

那么我们的整体思路就有了:

我们通过读取:ImagetSets文件夹下的txt文件获取不同任务需要的数据集图像ID。(这个ID是制作数据集的作者提前写好的),然后我们将这些ID转换成文件存储路径(训练数据,GroundTruth, 训练验证数据),再将路径信息转换成为list,该list中每一个元素对应一个样本。这样我们就可以通过__getitem__(),函数获取到数据图像和标签。

代码如下:

def _list_files(self):
    self.image_dir = os.path.join(self.root, "JPEGImages")
    self.label_dir = os.path.join(self.root, "SegmentationClass")
    if self.split in ['val', 'train', 'train_val', 'test']:
        self.file_list = os.path.join(self.root, "ImageSets\Segmentation", self.split+".txt")
        self.file_list = tuple(open(self.file_list, "r"))
        image_list = [os.path.join(self.image_dir, id.replace("\n", '') + '.jpg') for id in self.file_list]
        label_list = [os.path.join(self.label_dir, id.replace("\n", '') + '.png') for id in self.file_list]
        self.images = image_list
        self.labels = label_list
    else:
        raise ValueError("Invalid split name:{}".format(self.split))

读取数据图像代码如下:

def _load_datas(self, index):
    image_id = self.file_list[index].replace('\n', '')
    image_path = self.images[index]
    label_path = self.labels[index]
    # important!!! The image_path should not contain the charset with chinese location
    image = cv2.imread(image_path, cv2.IMREAD_COLOR).astype(np.float32)
    label = np.asarray(Image.open(label_path), dtype=np.int32)
    return image_id, image, label

【数据预处理】

    我们这里并没有直接在__getitem__函数中构造function,而是将转换list操作单独写入了一个方法中。这是为了方便我们对数据集进行一些预处理操作。(均值化,随机裁剪,填充,随机仿射等)。这是作者在阅读文献时积累的几个代码(基于cv2, numpy, PIL, random),但其实pytorch也为我们提供了庞大的数据预处理的库。详见:torchvision.transforms

随机裁剪代码如下:

def randomCrop(image, label, mean_bgr, crop_size=382):
    # Handle some except produres by height or width lower than the size of cropping operation
    h, w =label.shape
    pad_h = max(crop_size - h, 0)
    pad_w = max(crop_size - w, 0)
    mean_bgr = np.array(mean_bgr)
    pad_kwargs = {
        "top": 0,
        "bottom": pad_h,
        "left": 0,
        "right": pad_w,
        "borderType": cv2.BORDER_CONSTANT
    }
    if pad_h > 0 or pad_w > 0:
        image = cv2.copyMakeBorder(image, value=mean_bgr, **pad_kwargs)
        label = cv2.copyMakeBorder(label, value=0, **pad_kwargs)
    # Handle ended
    # Random Cropping
    h, w = label.shape
    start_h = random.randint(0, h - crop_size)
    start_w = random.randint(0, w - crop_size)
    end_h = start_h + crop_size
    end_w = start_w + crop_size
    image = image[start_h:end_h, start_w:end_w]
    label = label[start_h:end_h, start_w:end_w]
    # Cropping ended
    return image, label

随机翻转代码如下:

def randomFlip(image, label):
    if random.random() < 0.5:
        image = np.fliplr(image).copy()
        label = np.fliplr(label).copy()
    return image, label

随机放缩代码如下:

def randomScale(image, label, scales=(0.5, 0.75, 1.0, 1.25, 1.5)):
    h, w = label.shape
    scale_factor = random.choice(scales)
    h = int(h * scale_factor)
    w = int(w * scale_factor)
    image = cv2.resize(image, (w,h), interpolation=cv2.INTER_LINEAR)
    label = Image.fromarray(label).resize((w, h), resample=Image.NEAREST)
    label = np.asarray(label, dtype=np.int64)
    return image, label

【完整实现VOC2012】

VOC.py代码如下:

# -*- coding:utf-8 -*-

import torch
from torch.utils import data
from utils.dsetsHandle import randomFlip, randomCrop, randomScale
import cv2
from PIL import Image
import numpy as np
import os
class VOC12Dataset(data.Dataset):
    def __init__(self, root_path, split, mean_bgr=(104.008, 116.669, 122.675), augment=0):
        self.root = root_path
        self.augment = augment
        self.split = split
        self._list_files()
        self.mean_bgr = mean_bgr
        self.images
        self.labels
        self._load_datas(0)

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, index):
        (image_id, image, label) = self._load_datas(index)
        #数据预处理,增强操作
        if self.augment == 1:
            image, label = randomScale(self.images[index], self.labels[index])
        if self.augment == 2:
            image, label = randomCrop(self.images[index], self.labels[index], self.mean_bgr)
        if self.augment == 3:
            image, label = randomFlip(self.images[index], self.labels[index])
        image -= self.mean_bgr #均值化处理

        #image = image.transpose(2, 0, 1)

        return image_id, image.astype(np.float32), label.astype(np.int64)

    def _list_files(self):
        self.image_dir = os.path.join(self.root, "JPEGImages")
        self.label_dir = os.path.join(self.root, "SegmentationClass")
        if self.split in ['val', 'train', 'train_val', 'test']:
            self.file_list = os.path.join(self.root, "ImageSets\Segmentation", self.split+".txt")
            self.file_list = tuple(open(self.file_list, "r"))
            image_list = [os.path.join(self.image_dir, id.replace("\n", '') + '.jpg') for id in self.file_list]
            label_list = [os.path.join(self.label_dir, id.replace("\n", '') + '.png') for id in self.file_list]
            self.images = image_list
            self.labels = label_list
        else:
            raise ValueError("Invalid split name:{}".format(self.split))

    def _load_datas(self, index):
        image_id = self.file_list[index].replace('\n', '')
        image_path = self.images[index]
        label_path = self.labels[index]
        # important!!! The image_path should not contain the charset with chinese location
        image = cv2.imread(image_path, cv2.IMREAD_COLOR).astype(np.float32)
        label = np.asarray(Image.open(label_path), dtype=np.int32)
        return image_id, image, label

Utils.dsetsHandle.py

# -*- coding:utf-8 -*-
import cv2
import numpy as np
from PIL import Image
import random

def randomScale(image, label, scales=(0.5, 0.75, 1.0, 1.25, 1.5)):
    h, w = label.shape
    scale_factor = random.choice(scales)
    h = int(h * scale_factor)
    w = int(w * scale_factor)
    image = cv2.resize(image, (w,h), interpolation=cv2.INTER_LINEAR)
    label = Image.fromarray(label).resize((w, h), resample=Image.NEAREST)
    label = np.asarray(label, dtype=np.int64)
    return image, label

def randomCrop(image, label, mean_bgr, crop_size=382):
    # Handle some except produres by height or width lower than the size of cropping operation
    h, w =label.shape
    pad_h = max(crop_size - h, 0)
    pad_w = max(crop_size - w, 0)
    mean_bgr = np.array(mean_bgr)
    pad_kwargs = {
        "top": 0,
        "bottom": pad_h,
        "left": 0,
        "right": pad_w,
        "borderType": cv2.BORDER_CONSTANT
    }
    if pad_h > 0 or pad_w > 0:
        image = cv2.copyMakeBorder(image, value=mean_bgr, **pad_kwargs)
        label = cv2.copyMakeBorder(label, value=0, **pad_kwargs)
    # Handle ended
    # Random Cropping
    h, w = label.shape
    start_h = random.randint(0, h - crop_size)
    start_w = random.randint(0, w - crop_size)
    end_h = start_h + crop_size
    end_w = start_w + crop_size
    image = image[start_h:end_h, start_w:end_w]
    label = label[start_h:end_h, start_w:end_w]
    # Cropping ended
    return image, label

def randomFlip(image, label):
    if random.random() < 0.5:
        image = np.fliplr(image).copy()
        label = np.fliplr(label).copy()
    return image, label

def BGR2RGB(image):
    image = image.copy()
    temp = image[:,:, 0].copy()
    image[:, :, 0] = image[:, :, 2].copy()
    image[:, :, 2] = temp
    return image

【测试与结果】

测试代码如下:

# -*- coding:utf-8 -*-

from datasets.VOC import VOC12Dataset
import matplotlib.pyplot as plt
import cv2
a = VOC12Dataset(root_path='E:\VOCtrainval_11-May-2012\VOCdevkit\VOC2012', split='train')
plt.imshow(a[0][2])
plt.show()
cv2.imshow('a', a[0][1])
cv2.waitKey(0)
cv2.destroyAllWindows()

结果展示见下图:

image.png

其中:左图为原始图像,中间图为经过数据集加载后得到的图像,右图为pyplot输出的Ground-truth图像。

【参考文献】

作者 WellLee

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注