def blockshaped(arr, nrows, ncols):
"""
arr: (N, H, W)
Return an array of shape (n, h//nrows, w//ncols, nrows, ncols) where
n * h//nrows * w//ncols * nrows * ncols = arr.size
Each (nrows, ncols) subblock preserves the "physical" layout of arr,
i.e. each subblock is a contiguous patch of the corresponding image.
"""
n, h, w = arr.shape
assert h % nrows == 0, "{} rows is not evenly divisible by {}".format(h, nrows)
assert w % ncols == 0, "{} cols is not evenly divisible by {}".format(w, ncols)
# Original 2D version:
# return arr.reshape(h//nrows, nrows, -1, ncols).swapaxes(1,2).reshape(-1, nrows, ncols)
return ( arr.reshape(n, h//nrows, nrows, -1, ncols)
.transpose(0, 1, 3, 2, 4)
.reshape(-1, h//nrows, w//ncols, nrows, ncols) )
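As a quick sanity check, here is a minimal usage sketch (the batch size 10 and patch size 7 are only illustrative values, chosen to match the MNIST setup used later):

import numpy as np

# Hypothetical example: a batch of 10 "images" of size 28x28, split into 7x7 patches.
batch = np.arange(10 * 28 * 28).reshape(10, 28, 28)
patches = blockshaped(batch, 7, 7)
print(patches.shape)   # (10, 4, 4, 7, 7)
# The (0, 0) patch of the first image is its top-left 7x7 corner.
assert np.array_equal(patches[0, 0, 0], batch[0, :7, :7])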
The batch size is N, and each image has height H and width W.
We want to reshape the batch and return it as N x H//nrows x W//ncols x nrows x ncols (shape: (n, h//nrows, w//ncols, nrows, ncols)).
Why are we using reshape, transpose, and reshape?
Any 2D (or higher-dimensional) array is stored as a 1D sequence in memory.
Ex.
arr = [[0,1,2,3],
[4,5,6,7],
[8,9,10,11],
[12,13,14,15]]
→ In memory: [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]
reshape cannot change the order of the underlying sequence; it can only regroup it.
→ For example, reshape(2,2,2,2) would become
[[[[0,1],
[2,3]],
[[4,5],
[6,7]]],
[[[ 8, 9],
[10,11]],
[[12,13],
[14,15]]]]
which is a row-wise division.
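You can verify this regrouping directly; a small sketch using the same 4x4 arr as above:

import numpy as np

arr = np.arange(16).reshape(4, 4)   # the 4x4 example above
r = arr.reshape(2, 2, 2, 2)         # only regroups the flat sequence
print(r[0, 0])                      # [[0 1]
                                    #  [2 3]]  -> two rows, not a 2x2 patch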
What we want are small patches of the image, not a row-wise division, so we need transpose.
transpose(0,2,1,3) swaps the two middle axes, which changes what ends up in each block:
Before:
Block (0,0) = [[0,1],
[2,3]]
After:
Block (0,0) = [[0,1],
[4,5]]
The last line of the code, .reshape(-1, h//nrows, w//ncols, nrows, ncols), does not actually change anything: the transpose already produces exactly that shape. An assert on the output shape would express the intent more clearly.
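Continuing the same sketch, the transpose step produces the real patch, and the shape check suggested above could look like this (using the hypothetical 4x4 example):

import numpy as np

arr = np.arange(16).reshape(4, 4)
t = arr.reshape(2, 2, 2, 2).transpose(0, 2, 1, 3)
print(t[0, 0])                      # [[0 1]
                                    #  [4 5]]  -> the actual top-left 2x2 patch
# For the batched version the final reshape does not change the shape,
# so an explicit shape check documents the intent just as well:
out = blockshaped(arr[None, :, :], 2, 2)   # add a batch dimension of 1
assert out.shape == (1, 2, 2, 2, 2)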

import torch
import torch.nn as nn

class MnistAttn(nn.Module):
def __init__(self, image_size, patch_size, hidden_patch, n_class):
"""
Both the image and the patches are square.
image_size: size of the original input image before it is split into patches, e.g. 28
patch_size: size of each patch when the image is split, e.g. 7
hidden_patch: output size of the Linear layer applied to each patch, e.g. 16
n_class: number of classes, e.g. 10
"""
super(MnistAttn, self).__init__()
self.patch_size = patch_size
self.hidden_patch = hidden_patch
self.alpha = None # attention weights
assert image_size % patch_size == 0, \
    "{} is not evenly divisible by {}".format(image_size, patch_size)
self.n_patch = image_size // patch_size # 4
# 16 independent Linear layers (one for each 7x7 patch), each of shape (49, 16)
self.blocked_linears = torch.nn.ModuleList([
nn.Linear(patch_size**2, hidden_patch)
for i in range(self.n_patch**2) ])
# 16 attention score layers (one per patch), each of shape (49, 1)
self.attns = torch.nn.ModuleList([
nn.Linear(patch_size**2, 1)
for i in range(self.n_patch**2) ])
# (256, 10)
self.hidden = nn.Linear(self.n_patch**2 * hidden_patch, n_class)
self.logsoftmax = nn.LogSoftmax(dim=1)
def forward(self, x):
"""
x : (N, n_patch_rows, n_patch_cols, patch_rows, patch_cols)
"""
# The comments below assume a mini-batch of 10, an image size of (28, 28),
# a patch size of (7, 7), and a per-patch Linear output size of 16.
# In the comments:
# N: batch size
# o: output size of the Linear applied to each patch
# l: number of patches
N, _, _, Pr, Pc = x.shape
x = x.reshape(N, -1, Pr, Pc) # (N, 16, 7, 7)
x = x.reshape(N, self.n_patch**2, self.patch_size**2) # (N, 16, 49)
# linear_i(x[:, i]) == the i-th patch (i = 0..15) of every sample in the batch (N)
blocks = ( torch.cat([ nn.functional.relu(linear_i(x[:, i])) # (N, o) * l, (N, 16) * 16
# o (hidden_patch) is 16
for i, linear_i in enumerate(self.blocked_linears) ], dim=0) # concatenated along dim 0 -> (l*N, o)
.reshape(-1, N, self.hidden_patch).transpose(1,0) ) # (N, l, o), (N, 16, 16)
attn_weights = torch.cat( [ torch.tanh(attn_i(x[:, i]))
for i, attn_i in enumerate(self.attns) ], dim=1)
attn_weights = nn.functional.softmax(attn_weights, dim=1) # (N, l), (N, 16)
self.alpha = attn_weights
attn_weights = torch.unsqueeze(attn_weights, 2) # (N, l, 1), (N, 16, 1)
x = blocks * attn_weights # (N,l,o)*(N,l,1)=(N, 16, 16)*(N, 16, 1)=(N, 16, 16)
# reshape (N, 16, 16) -> (N, 16*16)
x = x.reshape(N,-1)
x = self.hidden(x) # (N,10)
x = self.logsoftmax(x)
return x
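To tie the two pieces together, a minimal, hypothetical usage sketch with random data in place of real MNIST (the values 10 / 28 / 7 / 16 / 10 mirror the comments above):

import numpy as np
import torch

model = MnistAttn(image_size=28, patch_size=7, hidden_patch=16, n_class=10)

# Fake mini-batch of 10 images, split into 7x7 patches with blockshaped.
images = np.random.rand(10, 28, 28).astype(np.float32)
patches = blockshaped(images, 7, 7)            # (10, 4, 4, 7, 7)
log_probs = model(torch.from_numpy(patches))   # (N, n_class)

print(log_probs.shape)    # torch.Size([10, 10])
print(model.alpha.shape)  # torch.Size([10, 16]): one attention weight per patch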