pytorchのDataLoaderについて

kaggleでよく目にする「DataLoader」の役割がよくわからなかったので調べ直しました。 以下ページを参考に(マルパクリして)実装しました。

gotutiyan.hatenablog.com

DataSetとDataLoaderを用いることで、ミニバッチ化を簡単に実装できる!そうです。

DataSet:元々のデータを全て持っていて、ある番号を指定されるとその番号の入出力のペアを返す。クラスを使って実装。
DataLoader:DataSetのインスタンスを渡すことで、ミニバッチ化した後のデータを返す。元から用意されている関数を呼び出す。

DataSetの実装

class DataSet:
    def __init__(self):
        self.X = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        self.t = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]

    def __len__(self.X)

    def __getitem__(self, index):

    return self.X[index], self,t[index]

以上がDataSetの実装。どのように振る舞うかというと

dataset = DataSet()
len(dataset) # 10
len(dataset[4]) # (4, 0)
len(dataset[2:5]) # ([2, 3, 4], [1, 0, 1])

上記の通りインデックスを指定するとデータとターゲットを返してくれる。

DataLoaderの実装

pytorchのモジュールに含まれるtorch.utils.data.DataLoader()を使う。

dataset = DataSet()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)

for data in dataloader:
    print(data)

'''
出力:
[tensor([4, 1]), tensor([0, 1])]
[tensor([0, 7]), tensor([0, 1])]
[tensor([9, 3]), tensor([1, 1])]
[tensor([6, 5]), tensor([0, 1])]
[tensor([8, 2]), tensor([0, 0])]
'''

batch_sizeでバッチサイズを指定し、shuffle=Trueとすることで順番がバラバラになる。

まとめ

DataLoaderの意義と実装方法がなんとなく分かったので、ぜひ使っていきたいと思います。