TTA(Test Time Augmentation)の実装(PyTorch)

  1. CFGでTTA定義
  2. datasetでもself.TTA定義
  3. train loopの際に画像をそれぞれ受け取ってGPUに乗せ、modelに入れて出力の平均をとる
class TestData(Dataset):
    """Dataset returning the stacked cell tiles of one sample, optionally with TTA views.

    Each item of ``cell_tiles`` is a sequence of per-cell images. Every image is
    passed through ``transform()(image=img)['image']`` and the results are
    stacked along a new leading dimension.

    Args:
        cell_tiles: Indexable collection; ``cell_tiles[index]`` is the sequence
            of images for one sample.
        transform: Zero-argument callable returning the transform to apply; the
            returned transform is called as ``t(image=img)`` and must return a
            mapping with an ``'image'`` tensor.
        TTA: When true, ``__getitem__`` returns four views of the stacked
            tensor instead of the tensor alone.
    """

    def __init__(self, cell_tiles=None, transform=None, TTA=False):
        super().__init__()
        self.cell_tiles = cell_tiles
        self.transform = transform
        self.TTA = TTA

    def __len__(self):
        """Return the number of samples (needed by DataLoader's default sampler)."""
        return len(self.cell_tiles)

    def __getitem__(self, index):
        """Return the stacked tiles for ``index``.

        Returns:
            The stacked tensor when ``TTA`` is false; otherwise the tuple
            ``(original, flipped, rotated_left, rotated_right)``.
        """
        # Loading is shared by both modes so the TTA views are always derived
        # from exactly the tensor the non-TTA path would return.
        tiles = self.cell_tiles[index]
        tiles = [self.transform()(image=tile)['image'] for tile in tiles]
        imgs = torch.stack(tiles, 0)

        if not self.TTA:
            return imgs

        # The augmentations act on dims 2 and 3, so the stacked tensor must be
        # 4-D (assumes each transformed tile is (C, H, W) -- TODO confirm).
        imgs2 = imgs.flip(2, 3)  # flip along both dim 2 and dim 3
        imgs3 = torch.rot90(imgs, 1, [2, 3])  # one 90-degree turn from dim 2 towards dim 3
        imgs4 = torch.rot90(imgs, -1, [2, 3])  # one 90-degree turn in the opposite direction
        return imgs, imgs2, imgs3, imgs4


        # Below: usage inside the loop (inference with TTA enabled)
        if CFG.TTA:
            testset = TestData(batch_cell_tiles, test_transform, TTA=True)
            testloader = DataLoader(
                testset,
                batch_size=1,
                shuffle=False,
                num_workers=4,
                collate_fn=collate_fn,
            )

            with torch.no_grad():
                model.eval()
                # Each batch holds four views; per the original note, `images`
                # is all cell images of one test image.
                for batch in testloader:
                    images, images2, images3, images4 = (view.to(DEVICE) for view in batch)

                    # One forward pass per view, squashed to probabilities.
                    probs = [
                        torch.sigmoid(model(view))
                        for view in (images, images2, images3, images4)
                    ]

                    # Sum left to right starting from the first view, then
                    # divide by the number of views to get the TTA average.
                    output = sum(probs[1:], probs[0]) / 4