Jarvis' Blog (总有美丽的风景让人流连) 总有美丽的风景让人流连

Tensorflow 自定义数据集 (Tensorflow-Datasets Customization, TFDS)

2020-11-25
Jarvis
Post

Tensorflow 提供了一个 tensorflow-datasets 的 Python 库来方便地下载、加载和管理数据集 (下文把 Tensorflow-Datasets 缩写为 TFDS).

提示

由于在中国大陆范围内谷歌服务不可用, 该 API 在中国大陆需要使用代理来下载数据集.

1. 安装

1
pip install tensorflow-datasets

1.1 配置代理

  • 选择任意的代理服务, 假设已经配置了 socks5 代理
  • 安装 privoxy 之类的工具, 用于把 socks5 转换为 http, https 和 ftp.
  • 运行任意涉及 tensorflow_datasetsbuilder.download_and_prepare() 语句的脚本时, 增加如下的环境变量:
1
2
3
export TFDS_HTTP_PROXY=http://127.0.0.1:8118
export TFDS_HTTPS_PROXY=http://127.0.0.1:8118
export TFDS_FTP_PROXY=http://127.0.0.1:8118

注意以上三个都要添加. 如果是希望仅对当前命令有效, 则可以直接添加到命令开头:

1
TFDS_HTTP_PROXY=http://127.0.0.1:8118 TFDS_HTTPS_PROXY=http://127.0.0.1:8118 TFDS_FTP_PROXY=http://127.0.0.1:8118 python demo.py

2. 使用内嵌数据 (在线下载)

TFDS 提供了丰富的内嵌数据集, 包括语音, 图像, 视频, 文本等, 例如知名的 mnist, cifar-10, cifar-100 等等.

对于小型数据集, 我们可以直接使用 API 下载(只会下载一次, 再次调用时会使用缓存的数据集).

提示

在中国大陆的用户可能无法直接从谷歌云存储(Google Cloud Storage, GCS) 下载这些数据集, 需要在运行脚本时按照第1.1节的方法指定代理服务器.

1
2
3
4
5
6
import tensorflow_datasets as tfds

# 指定要加载的数据集名称, split, 和数据下载解压后存放的位置.
datasets = tfds.load('mnist', 
                     split='train', 
                     data_dir='/data/tfds')

当然了也可以使用 builder API 获取数据集的详细信息:

1
2
3
4
5
builder = tfds.builder('mnist', data_dir='/data/tfds')
builder.download_and_prepare()
print(builder.info)

datasets = builder.as_datasets(split='train')

关于内嵌数据更详细的用法, 请参看官方文档.

3. 使用内嵌数据 (手动下载)

纯手动下载数据并使用 TFDS 加载是无法直接调用 API 做到的, 因为 TFDS 内嵌数据集的下载、解压、生成 TFRecord 是一气呵成的, 没有提供 API 来单独完成某一步. 因此这一节只讨论这几个步骤的结果, 从而我们找到数据后复制到其他设备的相应位置即可.

有些服务器在某些情况下可能无法使用代理, 这时候就需要本地下载好数据集后传到服务器上. 我们先需要对 TFDS 数据集的结构做一个分析. 一个完整的 TFDS 数据集包含以下几部分内容:

  • 数据集名称 dataset_name: 如 cifar10
  • 版本 version: 如 3.0.2
  • 数据 data
  • 数据信息 dataset info: 如 dataset_info.json
  • 数据特征 dataset features: 如 image.image.json, label.label.json

下载的数据 TFDS 会自动进行处理并保存为 TFRecord 格式. 假设我们设置的数据文件夹为 /data/tfds, 那么

  • 下载的数据保存在: /data/tfds/downloads (通常是压缩文件)
  • 解压的数据保存在: /data/tfds/downloads/extracted
  • 处理过的数据保存在: /data/tfds/<dataset_name>/<version>

那么拷贝时我们只需要把处理过的数据(即 TFRecord 数据 + dataset_info.json + 其他 *.json)复制到目标设备的数据目录下即可. 然后在使用 TFDS 加载数据时设置不下载数据:

1
2
3
4
datasets = tfds.load('mnist', 
                     split='train', 
                     data_dir='/data/tfds', 
                     download=False)

4. 使用自定义数据集

TFDS 也支持把我们自己的数据集构造为 TFDS 可以读取的格式, 这样就可以使用 TFDS 直接从本地硬盘读取了.

首先根据模板自定义数据集的生成文件:

  • 如果 tensorflow_datasets <= 3.2.0
1
2
3
python /<site-packages>/tensorflow_datasets/scripts/create_new_dataset.py \
  --dataset <dataset_name> \
  --type image   # text, audio, translation
  • 如果 tensorflow_datasets >= 4.0.0
1
2
3
tfds --help

tfds new <dataset_name>

然后会产生下面的文件(或文件夹), 这个例子是用 v3.1.0 的版本生成的, v4.0.0 及以后的版本略有不同, 但可以举一反三.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
"""my_dataset dataset."""

import tensorflow_datasets.public_api as tfds

_CITATION = """ """

_DESCRIPTION = """ """

class MyDataset(tfds.core.GeneratorBasedBuilder):
  """TODO(my_dataset): Short description of my dataset."""

  # TODO(my_dataset): Set up version.
  VERSION = tfds.core.Version('0.1.0')

  def _info(self):
    # TODO(my_dataset): Specifies the tfds.core.DatasetInfo object
    return tfds.core.DatasetInfo(
        builder=self,
        # This is the description that will appear on the datasets page.
        description=_DESCRIPTION,
        # tfds.features.FeatureConnectors
        features=tfds.features.FeaturesDict({
            # These are the features of your dataset like images, labels ...
        }),
        # If there's a common (input, target) tuple from the features,
        # specify them here. They'll be used if as_supervised=True in
        # builder.as_dataset.
        supervised_keys=(),
        # Homepage of the dataset for documentation
        homepage='https://dataset-homepage/',
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager):
    """Returns SplitGenerators."""
    # TODO(my_dataset): Downloads the data and defines the splits
    # dl_manager is a tfds.download.DownloadManager that can be used to
    # download and extract URLs
    path = dl_manager.download_and_extract('https://todo-data-url')

    return [
        tfds.core.SplitGenerator(
            name=tfds.Split.TRAIN,
            # These kwargs will be passed to _generate_examples
           gen_kwargs={},
        ),
    ]

  def _generate_examples(self):
    """Yields examples."""
    # TODO(my_dataset): Yields (key, example) tuples from the dataset
    yield 'key', {}

其中包含三个主要的函数:

  • _info(): 包含数据集的基本信息和每个数据样本的格式.
  • _split_generators(): 数据集的下载和划分.
  • _generate_examples(): 每个 split 的生成器, 每次生成一个样本

我们要做的就是把这个模板文件按照我们自己的数据集填写完毕. 填写的规则可以参考官方文档.

此外, 也可以进一步参考 /<site-packages>/tensorflow_datasets/image_classification/cifar.py 这类的例子.

这里我们要注意, 要使用本地的数据集, 我们需要把上面代码中下载数据的部分改为直接从本地读取即可:

1
2
3
4
# 把
path = dl_manager.download_and_extract('https://todo-data-url')
# 改为
path = "/path/to/your_dataset"

4.1 自定义数据集实例

下面给一个读取自定义数据集的一个例子. 该数据集为 PASCAL VOC 2012 数据集, 包含 20 个前景类别, 我们把包含第 1-5 类的图片作为测试数据, 把包含 6-20 类的图片作为训练数据, 打算使用自监督学习训练一个特征提取器. 因此我们需要创建两个只包含图像(不需要标签)的训练集和测试集, 代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""voc_fs dataset."""

import collections
from pathlib import Path

import tensorflow.compat.v1 as tf
import tensorflow_datasets.public_api as tfds

_CITATION = """ """

_DESCRIPTION = """
PASCAL VOC 2012 datasets for self-supervised learning.
Four splits:
     1 -  5:  aeroplane  bicycle  bird    boat    bootle 
     6 - 10:     bus       car    cat    chair     cow 
    11 - 15: diningtable   dog   horse  motorbike person 
    16 - 20:    plant     sheep   sofa   train     tv
"""


class VocFS(tfds.core.GeneratorBasedBuilder):
    """ PASCAL VOC 2012 datasets for self-supervised learning  """

    VERSION = tfds.core.Version('0.1.0')

    MANUAL_DOWNLOAD_INSTRUCTIONS = "Manually download VOC_FS"

    def _info(self):
        return tfds.core.DatasetInfo(
            builder=self,
            description=_DESCRIPTION,
            features=tfds.features.FeaturesDict({
                "id": tfds.features.Text(),
                "image": tfds.features.Image(shape=(None, None, 3), 
                                             encoding_format='jpeg'),
            }),
        )

    def _split_generators(self, dl_manager):
        """Returns SplitGenerators."""
        voc_fs_path = "/data/VOCdevkit/VOC2012" / self.name / str(self.VERSION)

        train_ids = []
        for i in range(6, 21):
            ids = (voc_fs_path / f"Binary_map_aug/train/{i}.txt")\
                .read_text().strip().splitlines()
            ids = [voc_fs_path / f"JPEGImages/{id_}.jpg" for id_ in ids]
            train_ids.extend(ids)
        train_ids = list(set(train_ids))
    
        test_ids = []
        for i in range(1, 6):
            ids = (voc_fs_path / f"Binary_map_aug/val/{i}.txt")\
                .read_text().strip().splitlines()
            ids = [voc_fs_path / f"JPEGImages/{id_}.jpg" for id_ in ids]
            test_ids.extend(ids)
        test_ids = list(set(test_ids))

        return [
            tfds.core.SplitGenerator(name='train', gen_kwargs={ "ids": train_ids }),
            tfds.core.SplitGenerator(name='test', gen_kwargs={ "ids": test_ids }),
        ]

    def _generate_examples(self, ids):
        """Yields examples."""
        for index, id_ in enumerate(ids): 
            yield index, { 'id': id_.stem, 'image': str(id_) }

使用 TFDS 加载数据集:

1
2
3
4
datasets = tfds.load('voc_fs', 
                     split='train', 
                     data_dir='/data/tfds', 
                     download=False)

其中第一次加载时 TFDS 会从自动生成 TFRecord 数据集并保存在 /data/tfds/voc_fs/0.1.0 目录下, 在第二次加载时就会直接读取了.


Content