pytorchのDataLoaderについて
kaggleでよく目にする「DataLoader」の役割がよくわからなかったので調べ直しました。 以下ページを参考に(マルパクリして)実装しました。
DataSetとDataLoaderを用いることで、ミニバッチ化を簡単に実装できる!そうです。
DataSet:元々のデータを全て持っていて、ある番号を指定されるとその番号の入出力のペアを返す。クラスを使って実装。
DataLoader:DataSetのインスタンスを渡すことで、ミニバッチ化した後のデータを返す。元から用意されている関数を呼び出す。
DataSetの実装
class DataSet: def __init__(self): self.X = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] self.t = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1] def __len__(self.X) def __getitem__(self, index): return self.X[index], self,t[index]
以上がDataSetの実装。どのように振る舞うかというと
dataset = DataSet() len(dataset) # 10 len(dataset[4]) # (4, 0) len(dataset[2:5]) # ([2, 3, 4], [1, 0, 1])
上記の通りインデックスを指定するとデータとターゲットを返してくれる。
DataLoaderの実装
pytorchのモジュールに含まれるtorch.utils.data.DataLoader()を使う。
dataset = DataSet() dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True) for data in dataloader: print(data) ''' 出力: [tensor([4, 1]), tensor([0, 1])] [tensor([0, 7]), tensor([0, 1])] [tensor([9, 3]), tensor([1, 1])] [tensor([6, 5]), tensor([0, 1])] [tensor([8, 2]), tensor([0, 0])] '''
batch_sizeでバッチサイズを指定し、shuffle=Trueとすることで順番がバラバラになる。
まとめ
DataLoaderの意義と実装方法がなんとなく分かったので、ぜひ使っていきたいと思います。