pytorch torch.utils.data 源码分析

dataset, datasetloader, sampler

Dataset

所有表示从 key 到数据样本的映射的数据集都应该继承它。所有子类都应该覆盖 ‘ __getitem__ ‘方法,支持获取给定key的数据样本。子类也可以选择性地覆盖: ‘ __len__ ‘方法,许多类期望获得数据集的大小, 如: ‘ torch.utils.data.Sampler的实现和默认选项类’ torch.utils.data.DataLoader ‘ 。

class Dataset(object):

    def __getitem__(self, index):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

ConcatDataset 是Dataset的子类,按理来讲,如果Dataset作为一个接口或者说抽象类来说,是不会有__add__这个实现的,耦合,trick太多了。实现了这个__add__之后,就可以用加号来使用。

子类

IterableDataset

就像是名字一样,是一个 iterable dataset,多了一个__iter__用以实现,已经不像是一个key到sample的映射,而是一个可以直接迭代的数据集。实现后已经类似于Sampler子类了。

如下面这段源码,DateLoader类中加载,如果dataset的类别是Iterable的时候

    def __iter__(self):
        raise NotImplementedError
ChainDataset(IterableDataset)

ChainDataset的datasets必须都是IterableDataset的子类

    def __iter__(self):
        for d in self.datasets:
            assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
            for x in d:
                yield x

TensorDataset

每个样本都将通过索引第一个维度上的张量来检索。

ConcatDataset

拼接多个Dateset, 相当于按顺序给后面的数据集使用累加的索引,通过累加的索引访问所有的数据集,虽然是通过类似于二维数组实现的(就是)

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

Subset

子数据集,利用一段索引来访问数据集

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

Sampler

迭代Dataset的索引,如: Dataset length = 5 则SequentialSampler会给出iter([0,1,2,3,4]) ,Sampler的__iter__()会给出一个list的iterator,即为数据集indices,可以通过这些indices 与 DatasetFetcher来获取样本数据。

class Sampler(object):
    r"""Base class for all Samplers.

    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices of dataset elements, and a :meth:`__len__` method
    that returns the length of the returned iterators.

    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """
    def __init__(self, data_source):
        pass

    def __iter__(self):
        raise NotImplementedError

子类

SequentialSampler

    def __iter__(self):
        return iter(range(len(self.data_source)))

RandomSampler

给出一个随机的索引List 的 iterator

    def __iter__(self):
        n = len(self.data_source)
        if self.replacement:
            return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
        return iter(torch.randperm(n).tolist())

SubsetRandomSampler

在给定的一段索引列表中随机,相当于shuffle一段子数据集的索引

    def __iter__(self):
        return (self.indices[i] for i in torch.randperm(len(self.indices)))

WeightedRandomSampler

multinomial,通过weights 用来作为下标的权重,最后返回的是weights的索引(又因为有不同的权重,所以会优先选择权重较大的下标)

    def __iter__(self):
        return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())

BatchSampler

初始化输入时sampler,也就是说这个类其实是对一个 sampler 迭代,每一次迭代batch_size次;drop_last是决定是否舍弃最后不足 batch_size 大小的一个批次,如果为False,也就是不舍弃,那么最后一个批次较小。

def __init__(self, sampler, batch_size, drop_last)
    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

DataLoader 与 _BaseDataLoaderIter

重点来了

前置内容

_DatasetKind DatasetDataLoader _BaseDataLoaderIter_BaseDatasetFetcherSampler

先介绍一下_DatasetKind,类似于java的枚举类,定义了两个常量用以标注Dataset的类型,如果 isinstance(dataset, IterableDataset),即dataset是用 IterableDataset 实现的话,就是Iterable 类型的,其余的都是Map类型。

Map = 0
Iterable = 1

_DatasetKind是为在Fetch的时候选择不同的_BaseDatasetFetcher子类,_MapDatasetFetcher 和_IterableDatasetFetcher,用以获得样本的数据(如果有cudn会写入到gpu的内存中)

_BaseDataLoaderIter 有两个子类_SingleProcessDataLoaderIter 和_MultiProcessingDataLoaderIter,选择哪个子类是 num_workers 决定的 ,也就是是否多线程加载样本数据。

还有Python iterator的一些东西

理解这张图,先看一下前面的前置内容。 上面的图简化的代码就是下面4行代码,当然只是伪代码。

dataset = Dataset()
dataloader = Dataloader(dataset,...)
iterator = iter(dataloader)
data = next(iterator)

dataset = Dataset() 获得数据集的实例,dataloader = Dataloader(dataset,…) 获得 dataloader 的实例,通过 dataloader 的__iter__()获得iterator也就是_BaseDataLoaderIter ,通过Simpler获取indices ,在将indices输入到DatasetFetcher中获取样本数据。这就是一整个流程。





除非注明,否则均为一叶呼呼原创文章,转载必须以链接形式标明本文链接

本文链接:http://www.yiyehu.tech/archives/2020/04/15/pytorch-torch-utils-data-sourcecode

发表评论

电子邮件地址不会被公开。 必填项已用*标注