import torch
x = torch.arange(8*12).view(1,1,8,12)
m=x.unfold(2, 4, 4)
n = m.unfold(3, 4, 4)
输入
第一次切,切高度维度,但是切完做了转置 ,得到(1,1,2,12,4)
切宽度 得到 张量维度1 1 2 3 4 4 最后得到 维度表示:批次 通道 高度切块数 宽度切块数 高度 宽度
patches.contiguous().view(B, -1, 4*4)其中contiguous()不用管,不改变张量形状.view(B, -1, 4*4),把张量第一维度是 B,第二维度-1代表自动计算大小,其实就是总共切块的数目,最后一维度代表一个块的大小