pytorchのDataLoaderについて
kaggleでよく目にする「DataLoader」の役割がよくわからなかったので調べ直しました。 以下ページを参考に(マルパクリして)実装しました。
DataSetとDataLoaderを用いることで、ミニバッチ化を簡単に実装できる!そうです。
DataSet:元々のデータを全て持っていて、ある番号を指定されるとその番号の入出力のペアを返す。クラスを使って実装。
DataLoader:DataSetのインスタンスを渡すことで、ミニバッチ化した後のデータを返す。元から用意されている関数を呼び出す。
DataSetの実装
class DataSet: def __init__(self): self.X = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] self.t = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1] def __len__(self.X) def __getitem__(self, index): return self.X[index], self,t[index]
以上がDataSetの実装。どのように振る舞うかというと
dataset = DataSet() len(dataset) # 10 len(dataset[4]) # (4, 0) len(dataset[2:5]) # ([2, 3, 4], [1, 0, 1])
上記の通りインデックスを指定するとデータとターゲットを返してくれる。
DataLoaderの実装
pytorchのモジュールに含まれるtorch.utils.data.DataLoader()を使う。
dataset = DataSet() dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True) for data in dataloader: print(data) ''' 出力: [tensor([4, 1]), tensor([0, 1])] [tensor([0, 7]), tensor([0, 1])] [tensor([9, 3]), tensor([1, 1])] [tensor([6, 5]), tensor([0, 1])] [tensor([8, 2]), tensor([0, 0])] '''
batch_sizeでバッチサイズを指定し、shuffle=Trueとすることで順番がバラバラになる。
まとめ
DataLoaderの意義と実装方法がなんとなく分かったので、ぜひ使っていきたいと思います。
Microsoft Malware Prediction
Kaggleのコンペの1つ「Microsoft Malware Prediction」のカーネルまとめ
CVとLBの点数に隔たりがある理由について解説しているカーネル。
原因は、TRAINデータとTESTデータでそれぞれ別の分布から引っ張ってきているから。具体的には、TRAINデータは2018年の8,9月のデータであるのに対し、Public TESTは10月、Private TESTは11月のデータ。検証のためにそれぞれのデータの分布を可視化し、TRAINデータを時系列Validationした場合の分布がランダムでValidationをした時に比べてTESTデータの分布に近いことを示している。
モデル1(ランダムValidation)とモデル2(時系列Validation)について、Adversarial Validationの手法を使ってそれぞれを区別できるかどうかを検証。モデル1では区別できる(テストデータと違いがある)が、モデル2ではできないことを実証。
最後に、TrainとPublic Testで違いが大きい説明変数を挙げ、より正確なValidationモデルを作るために分布がTestデータからずれているデータを、Testデータと分布が近い新しい値に変換する必要があると述べている。
TrainデータとTestデータがどちらもサイズが大きい。これを解決する方法を解説するカーネル。
データがアップスケールされて無駄にスペース(メモリ)を消費していないかチェックする。チェックする際にはすべてのデータを読み込む必要はなく、pd.read_csvのchunksizeを使うとよい(参考:pandas でメモリに乗らない 大容量ファイルを上手に扱う - StatsFragments)。
データ読み込みのコツは、 ①オブジェクトデータをカテゴリ変数として読み込む ②バイナリはint8指定 ③欠損値があるバイナリはfloat16に変換(intはnanを読み取れないため) ④64bitのエンコーディングはなるべく32,16bitへ変換する。 これらはread_csvのdatatypesを指定することで実現できる。