MNIST

class Mnist(nn.Module):
    def __init__(self, image_size, patch_size, hidden_patch, n_class):
        """
        이미지와 패치는 정사각형
        image_size:   입력되는 쪼개지기 전의 원래 이미지 크기, 28
        patch_size:   이미지를 패치로 쪼갤때 패치 하나의 크기, 7
        hidden_patch: 각 패치에 적용되는 Linear층의 출력 개수, 16
        n_class:      클래스 수, 10
        """
        super(Mnist, self).__init__()

        self.patch_size = patch_size

        assert image_size % patch_size == 0, \\
            "{} is not evenly divisble by {}".format(image_size, patch_size)

        self.n_patch = image_size // patch_size # 28 // 7 = 4
        self.blocked_linears = torch.nn.ModuleList([
														        # patch_size**2 == 각 patch의 pixel
                                    nn.Linear(patch_size**2, hidden_patch)
                                    # 총 16개의 patch들에 대해 pixel들을 독립적으로
                                    # Linear Layer를 통과하게 함.
                                    for i in range(self.n_patch**2) ])
        self.hidden = nn.Linear(self.n_patch**2 * hidden_patch, n_class)
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """
        x : (N, n_patch_rows, n_patch_cols, patch_rows, patch_cols)
        """
        # 주석은 미니배치 10, 이미지사이즈가 (28,28), 패치사이즈가 (7,7)
        # 패치가 Linear()를 통과한후 나오는 출력이 16인 경우를 가정한다.
        # 주석에서
        # N: 배치 사이즈
        N, _, _, Pr, Pc = x.shape

        # forward는 분할된 이미지를 입력받는다. 이 경우 입력은 (10, 4, 4, 7, 7)
        # 4x4개 패치를 16개 한 차원으로 몰아넣는다.
        x = x.reshape(N, -1, Pr, Pc) # (N, 16, 7, 7)

        # (7,7)인 개별 패치를 한줄로 편다
        x = x.reshape(N, self.n_patch**2, self.patch_size**2) # (N, 16, 49),

        # ModuleList에 저장된 Linear() 열여섯개를 순차적으로 포워드시킨다.
        # 그 열여섯개 결과를 1번축으로 이어붙인다. -> 256
        blocks = torch.cat( [ linear_i(x[:, i])
                for i, linear_i in enumerate(self.blocked_linears) ], dim=1) # (N, 16*16)

        x = nn.functional.relu(blocks) # (N, 256)->(N, 256)
        x = self.hidden(x) # (N, 256)->(N,10)
        x = self.logsoftmax(x) # (N,10)->(N,10)

        return x

What exactly does ModuleList do??

nn.ModuleList 여러 개의 sub module을 list의 형태로 묶어놓는 container

MnistAttn

class MnistAttn(nn.Module):
    def __init__(self, image_size, patch_size, hidden_patch, n_class):
        """
        이미지와 패치는 정사각형
        image_size:   입력되는 쪼개지기 전의 원래 이미지 크기, 28
        patch_size:   이미지를 패치로 쪼갤때 패치 하나의 크기, 7
        hidden_patch: 각 패치에 적용되는 Linear층의 출력 개수, 16
        n_class:      클래스 수, 10
        """
        super(MnistAttn, self).__init__()

        self.patch_size = patch_size
        self.hidden_patch = hidden_patch
        self.alpha = None # attention weights

        assert image_size % patch_size == 0, \\
            "{} is not evenly divisble by {}".format(image_size, patch_size)

        self.n_patch = image_size // patch_size # 4
        # 16 (49, 16)
        self.blocked_linears = torch.nn.ModuleList([
                                    nn.Linear(patch_size**2, hidden_patch)
                                    for i in range(self.n_patch**2) ])
        # 16 (49, 1)
        self.attns = torch.nn.ModuleList([
                                    nn.Linear(patch_size**2, 1)
                                    for i in range(self.n_patch**2) ])
        # (256, 10)
        self.hidden = nn.Linear(self.n_patch**2 * hidden_patch, n_class)
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """
        x : (N, n_patch_rows, n_patch_cols, patch_rows, patch_cols)
        """
        # 주석은 미니배치 10, 이미지사이즈가 (28,28), 패치사이즈가 (7,7)
        # 패치가 Linear()를 통과한후 나오는 출력이 16인 경우를 가정한다.
        # 주석에서
        # N: 배치 사이즈
        # o: 패치를 입력받는 Linear의 출력 사이즈
        # l: 패치 개수
        N, _, _, Pr, Pc = x.shape

        x = x.reshape(N, -1, Pr, Pc) # (N, 16, 7, 7)
        x = x.reshape(N, self.n_patch**2, self.patch_size**2) # (N, 16, 49)
				
				# linear_i(x[:, i]) == 모든 batch(N)의 i(0~15)번째 patch들
        blocks = ( torch.cat([ nn.functional.relu(linear_i(x[:, i])) # (N, o) * l, (N, 16) * 16
                    for i, linear_i in enumerate(self.blocked_linears) ], dim=0) # 행 방향으로 더함
                    .reshape(-1, N, self.hidden_patch).transpose(1,0) ) # (N, l, o), (N, 16, 16)

        attn_weights = torch.cat( [ torch.tanh(attn_i(x[:, i]))
                for i, attn_i in enumerate(self.attns) ], dim=1)
        attn_weights = nn.functional.softmax(attn_weights, dim=1) # (N, l), (N, 16)
        self.alpha = attn_weights
        attn_weights = torch.unsqueeze(attn_weights, 2) # (N, l, 1), (N, 16, 1)

        x = blocks * attn_weights # (N,l,o)*(N,l,1)=(N, 16, 16)*(N, 16, 1)=(N, 16, 16)

        # reshape (N, 16, 16) -> (N, 16*16)
        x = x.reshape(N,-1)

        x = self.hidden(x) # (N,10)
        x = self.logsoftmax(x)

        return x