- reshape()
- squeeze()
- unsqueeze()
- transpose()
- permute()
- view() == reshape()
- contiguous() == reshape()
一、reshape() 函数
保证张量数据不变的前提下改变数据的维度,将其转换成指定的形状。
def reshape_tensor():data = torch.tensor([[1, 2, 3], [4, 5, 6]])print(data, data.shape) # pytorch中shape=size() 都可以获得张量的形状data1=data.reshape(1, 6) #reshape行列相乘要等于数据data的总个数# data1=data.reshape(6) # 等于上一行print(data1, data1.shape)data2=data.reshape(3,-1) #-1 自动推断,已知三行自动推断列数# data2=data.reshape(-1) #-1 自动推断print(data2, data2.shape)if __name__ == '__main__':reshape_tensor()
扩展
pytorch中 shape=size() 都可以获得张量的形状
# 扩展:size,和shape是等价的,都是看数据的维度 # print("data1-->", data1.shape, data1.size()) # shape[0] <=> size()[0] <=> size(0) # shape[1] <=> size()[1] <=> size(1)print("data1-->", data, data.shape[0], data.size(1))
二、squeeze() 和 unsqueeze()
squeeze 函数删除 形状为 1 的维度(升维),unsqueeze 函数添加形状为1的维度(降维)。
# 生维与降维
def unsqueeze_squeeze_tensor():# 准备数据data = torch.tensor([1, 2, 3, 4, 5])print('data-->', data, data.shape)# 升维: unsqueeze(), 增加一个维度,这个维度的长度为1# data1 = data.unsqueeze(dim=0) # [1, 5]# data1 = data.unsqueeze(dim=1) # [5, 1]data1 = data.unsqueeze(dim=-1).unsqueeze(dim=0) # [1, 5, 1]# data1 = data.unsqueeze(dim=2) # [5, 1] 会报错,越界print("data1-->", data1, data1.shape)# 降维: squeeze(), 能够减少维度为1的维度# 所有长度为1的维度都会降低。data2 = data1.squeeze()# print("data2-->", data2, data2.shape)
if __name__ == '__main__':unsqueeze_squeeze_tensor()
三、transpose() 和 permute()
transpose 函数可以实现交换张量形状的指定维度, 例如: 一个张量的形状为 (2, 3, 4) 可以通过 transpose 函数把 3 和 4 进行交换, 将张量的形状变为 (2, 4, 3)
permute 函数可以一次交换更多的维度。
def transpose_permute_tensor():# 生成随机张量,并设置随机种子,保持随机张量是固定值torch.manual_seed(0)data = torch.randint(0, 10, (3, 4, 5))print(data, data.shape)# .transpose指定交换的两个维度data1=data.transpose(1, 2) # torch.Size([3, 5, 4])# data1=data.transpose(0, 2)print(data1, data1.shape) # torch.Size([5, 4, 3])# .permute指定交换的多个维度data2 = data.permute(2,0,1)print(data2, data2.shape) # torch.Size([5, 3, 4])
if __name__ == '__main__':transpose_permute_tensor()
五、view() 和 contiguous()
view 函数也可以用于修改张量的形状,但是其用法比较局限,只能用于存储在整块内存中的张量。在 PyTorch 中,有些张量是由不同的数据块组成的,它们并没有存储在整块的内存中,view 函数无法对这样的张量进行变形处理,例如: 一个张量经过了 transpose 函数的处理之后,就无法使用 view 函数进行形状操作。
view 函数也可以用于修改张量的形状, 但是它要求被转换的张量内存必须连续,所以一般配合 contiguous 函数使用。
def view_contiguous_tensor():torch.manual_seed(0)data = torch.randint(0, 10, (3, 4, 5))print(data, data.shape)data = data.transpose(1, 2) # 不连续 torch.Size([3, 5, 4])# data = data.permute(1, 2, 0)# data = data.view(1, 2, -1) # data不连续后,调用view函数会报错print(data, data.shape)print(data.is_contiguous()) # 判断是否连续# print(data.contiguous().is_contiguous()) # 通过contiguous把不连续的内存空间变成连续print(data.contiguous().view(3,4,5)) # 再view()
if __name__ == '__main__':view_contiguous_tensor()
六、小结
- reshape 函数可以在保证张量数据不变的前提下改变数据的维度
- squeeze 和 unsqueeze 函数可以用来增加或者减少维度
- transpose 函数可以实现交换张量形状的指定维度, permute 可以一次交换更多的维度
- view 函数也可以用于修改张量的形状, 但是它要求被转换的张量内存必须连续, 所以一般配合 contiguous 函数使用