pytorchのDataLoaderについて

kaggleでよく目にする「DataLoader」の役割がよくわからなかったので調べ直しました。 以下ページを参考に(マルパクリして)実装しました。

gotutiyan.hatenablog.com

DataSetとDataLoaderを用いることで、ミニバッチ化を簡単に実装できる!そうです。

DataSet:元々のデータを全て持っていて、ある番号を指定されるとその番号の入出力のペアを返す。クラスを使って実装。
DataLoader:DataSetのインスタンスを渡すことで、ミニバッチ化した後のデータを返す。元から用意されている関数を呼び出す。

DataSetの実装

class DataSet:
    def __init__(self):
        self.X = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        self.t = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]

    def __len__(self.X)

    def __getitem__(self, index):

    return self.X[index], self,t[index]

以上がDataSetの実装。どのように振る舞うかというと

dataset = DataSet()
len(dataset) # 10
len(dataset[4]) # (4, 0)
len(dataset[2:5]) # ([2, 3, 4], [1, 0, 1])

上記の通りインデックスを指定するとデータとターゲットを返してくれる。

DataLoaderの実装

pytorchのモジュールに含まれるtorch.utils.data.DataLoader()を使う。

dataset = DataSet()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)

for data in dataloader:
    print(data)

'''
出力:
[tensor([4, 1]), tensor([0, 1])]
[tensor([0, 7]), tensor([0, 1])]
[tensor([9, 3]), tensor([1, 1])]
[tensor([6, 5]), tensor([0, 1])]
[tensor([8, 2]), tensor([0, 0])]
'''

batch_sizeでバッチサイズを指定し、shuffle=Trueとすることで順番がバラバラになる。

まとめ

DataLoaderの意義と実装方法がなんとなく分かったので、ぜひ使っていきたいと思います。

Microsoft Malware Prediction

Kaggleのコンペの1つ「Microsoft Malware Prediction」のカーネルまとめ


www.kaggle.com

  • CVとLBの点数に隔たりがある理由について解説しているカーネル

  • 原因は、TRAINデータとTESTデータでそれぞれ別の分布から引っ張ってきているから。具体的には、TRAINデータは2018年の8,9月のデータであるのに対し、Public TESTは10月、Private TESTは11月のデータ。検証のためにそれぞれのデータの分布を可視化し、TRAINデータを時系列Validationした場合の分布がランダムでValidationをした時に比べてTESTデータの分布に近いことを示している。

  • モデル1(ランダムValidation)とモデル2(時系列Validation)について、Adversarial Validationの手法を使ってそれぞれを区別できるかどうかを検証。モデル1では区別できる(テストデータと違いがある)が、モデル2ではできないことを実証。

  • 最後に、TrainとPublic Testで違いが大きい説明変数を挙げ、より正確なValidationモデルを作るために分布がTestデータからずれているデータを、Testデータと分布が近い新しい値に変換する必要があると述べている。



www.kaggle.com

  • TrainデータとTestデータがどちらもサイズが大きい。これを解決する方法を解説するカーネル

  • データがアップスケールされて無駄にスペース(メモリ)を消費していないかチェックする。チェックする際にはすべてのデータを読み込む必要はなく、pd.read_csvのchunksizeを使うとよい(参考:pandas でメモリに乗らない 大容量ファイルを上手に扱う - StatsFragments)。

  • データ読み込みのコツは、  ①オブジェクトデータをカテゴリ変数として読み込む  ②バイナリはint8指定  ③欠損値があるバイナリはfloat16に変換(intはnanを読み取れないため)  ④64bitのエンコーディングはなるべく32,16bitへ変換する。  これらはread_csvのdatatypesを指定することで実現できる。