class Mnist(nn.Module):
def __init__(self, image_size, patch_size, hidden_patch, n_class):
"""
이미지와 패치는 정사각형
image_size: 입력되는 쪼개지기 전의 원래 이미지 크기, 28
patch_size: 이미지를 패치로 쪼갤때 패치 하나의 크기, 7
hidden_patch: 각 패치에 적용되는 Linear층의 출력 개수, 16
n_class: 클래스 수, 10
"""
super(Mnist, self).__init__()
self.patch_size = patch_size
assert image_size % patch_size == 0, \\
"{} is not evenly divisble by {}".format(image_size, patch_size)
self.n_patch = image_size // patch_size # 28 // 7 = 4
self.blocked_linears = torch.nn.ModuleList([
# patch_size**2 == 각 patch의 pixel
nn.Linear(patch_size**2, hidden_patch)
# 총 16개의 patch들에 대해 pixel들을 독립적으로
# Linear Layer를 통과하게 함.
for i in range(self.n_patch**2) ])
self.hidden = nn.Linear(self.n_patch**2 * hidden_patch, n_class)
self.logsoftmax = nn.LogSoftmax(dim=1)
def forward(self, x):
"""
x : (N, n_patch_rows, n_patch_cols, patch_rows, patch_cols)
"""
# 주석은 미니배치 10, 이미지사이즈가 (28,28), 패치사이즈가 (7,7)
# 패치가 Linear()를 통과한후 나오는 출력이 16인 경우를 가정한다.
# 주석에서
# N: 배치 사이즈
N, _, _, Pr, Pc = x.shape
# forward는 분할된 이미지를 입력받는다. 이 경우 입력은 (10, 4, 4, 7, 7)
# 4x4개 패치를 16개 한 차원으로 몰아넣는다.
x = x.reshape(N, -1, Pr, Pc) # (N, 16, 7, 7)
# (7,7)인 개별 패치를 한줄로 편다
x = x.reshape(N, self.n_patch**2, self.patch_size**2) # (N, 16, 49),
# ModuleList에 저장된 Linear() 열여섯개를 순차적으로 포워드시킨다.
# 그 열여섯개 결과를 1번축으로 이어붙인다. -> 256
blocks = torch.cat( [ linear_i(x[:, i])
for i, linear_i in enumerate(self.blocked_linears) ], dim=1) # (N, 16*16)
x = nn.functional.relu(blocks) # (N, 256)->(N, 256)
x = self.hidden(x) # (N, 256)->(N,10)
x = self.logsoftmax(x) # (N,10)->(N,10)
return x
What exactly does ModuleList do??
nn.ModuleList 여러 개의 sub module을 list의 형태로 묶어놓는 container
class MnistAttn(nn.Module):
def __init__(self, image_size, patch_size, hidden_patch, n_class):
"""
이미지와 패치는 정사각형
image_size: 입력되는 쪼개지기 전의 원래 이미지 크기, 28
patch_size: 이미지를 패치로 쪼갤때 패치 하나의 크기, 7
hidden_patch: 각 패치에 적용되는 Linear층의 출력 개수, 16
n_class: 클래스 수, 10
"""
super(MnistAttn, self).__init__()
self.patch_size = patch_size
self.hidden_patch = hidden_patch
self.alpha = None # attention weights
assert image_size % patch_size == 0, \\
"{} is not evenly divisble by {}".format(image_size, patch_size)
self.n_patch = image_size // patch_size # 4
# 16 (49, 16)
self.blocked_linears = torch.nn.ModuleList([
nn.Linear(patch_size**2, hidden_patch)
for i in range(self.n_patch**2) ])
# 16 (49, 1)
self.attns = torch.nn.ModuleList([
nn.Linear(patch_size**2, 1)
for i in range(self.n_patch**2) ])
# (256, 10)
self.hidden = nn.Linear(self.n_patch**2 * hidden_patch, n_class)
self.logsoftmax = nn.LogSoftmax(dim=1)
def forward(self, x):
"""
x : (N, n_patch_rows, n_patch_cols, patch_rows, patch_cols)
"""
# 주석은 미니배치 10, 이미지사이즈가 (28,28), 패치사이즈가 (7,7)
# 패치가 Linear()를 통과한후 나오는 출력이 16인 경우를 가정한다.
# 주석에서
# N: 배치 사이즈
# o: 패치를 입력받는 Linear의 출력 사이즈
# l: 패치 개수
N, _, _, Pr, Pc = x.shape
x = x.reshape(N, -1, Pr, Pc) # (N, 16, 7, 7)
x = x.reshape(N, self.n_patch**2, self.patch_size**2) # (N, 16, 49)
# linear_i(x[:, i]) == 모든 batch(N)의 i(0~15)번째 patch들
blocks = ( torch.cat([ nn.functional.relu(linear_i(x[:, i])) # (N, o) * l, (N, 16) * 16
for i, linear_i in enumerate(self.blocked_linears) ], dim=0) # 행 방향으로 더함
.reshape(-1, N, self.hidden_patch).transpose(1,0) ) # (N, l, o), (N, 16, 16)
attn_weights = torch.cat( [ torch.tanh(attn_i(x[:, i]))
for i, attn_i in enumerate(self.attns) ], dim=1)
attn_weights = nn.functional.softmax(attn_weights, dim=1) # (N, l), (N, 16)
self.alpha = attn_weights
attn_weights = torch.unsqueeze(attn_weights, 2) # (N, l, 1), (N, 16, 1)
x = blocks * attn_weights # (N,l,o)*(N,l,1)=(N, 16, 16)*(N, 16, 1)=(N, 16, 16)
# reshape (N, 16, 16) -> (N, 16*16)
x = x.reshape(N,-1)
x = self.hidden(x) # (N,10)
x = self.logsoftmax(x)
return x